home *** CD-ROM | disk | FTP | other *** search
/ Liren Large Software Subsidy 10 / 10.iso / l / l460 / 2.ddi / DATAFUN.DI$ / CONV2.C < prev    next >
Encoding:
C/C++ Source or Header  |  1993-03-22  |  7.4 KB  |  263 lines

  1. /*********************************************************************/
  2. /*                        R C S  information                         */
  3. /*********************************************************************/
  4. #ifndef NO_RCS_INFO
  5. static char rcsid[] = "$Header: /hub/rel/stage4.win/stage/toolbox/matlab/datafun/RCS/conv2.c,v 1.3 1993/03/15 19:23:59 clay Exp $";
  6. #endif /* NO_RCS_INFO */
  7.  
  8. /* $Log: conv2.c,v $
  9.  * Revision 1.3  1993/03/15  19:23:59  clay
  10.  * Add shape flag to c-code.  See checkin message for conv2.m
  11.  *
  12.  * Revision 1.2  1992/12/07  22:49:18  martin
  13.  * The definiton of the header for mexFunction must be ifdefed
  14.  * to allow for Standard C compilers. Also, mexFunction should not
  15.  * be declared void for non Standard C compilers.
  16.  *                     Martin
  17.  *
  18.  * Revision 1.1  1992/10/06  18:41:13  loren
  19.  * Initial revision
  20.  * */
  21. /*********************************************************************/
  22.  
  23. /*
  24.  
  25.     CONV2.C    .MEX file for corresponding to CONV2.M
  26.         Implements a 2-D convolution
  27.  
  28.     Syntax    c = conv2(a,b) or c = conv2(a,b,shape) where
  29.             shape is one of 'same','full','valid'.  'full'
  30.             is the default.
  31.  
  32.     Clay M. Thompson 10-4-91
  33.     L. Shure 10-21-91 - modified to handle complex case
  34.     CMT 3-10-93       - modified to take shape parameter
  35.  
  36. */
  37.  
  38. #include <math.h>
  39. #include "mex.h"
  40.  
  41. #ifndef THINK_C
  42. #define DOUBLE double
  43. #define INT int
  44. #else
  45. #define DOUBLE short double
  46. #define INT long
  47. #endif
  48.  
  49. /* Input Arguments */
  50.  
  51. #define    A_IN    prhs[0]
  52. #define    B_IN    prhs[1]
  53. #define S_IN    prhs[2]
  54.  
  55. /* Output Arguments */
  56.  
  57. #define    C_OUT    plhs[0]
  58.  
  59. /* define constants */
  60. #define PLUS 1
  61. #define MINUS -1
  62.  
  63. /* Extract submatrix from a.  b = a(rstart:rend,cstart:cend); */
  64. void subMatrix(b, a, ma, na, rstart, rend, cstart, cend)
  65.     DOUBLE *b;    /* Result matrix (rend-rstart+1)-by-(cend-cstart+1) */
  66.     DOUBLE *a;    /* Original matrix ma-by-na */
  67.     INT ma;        /* Row size of a */
  68.     INT na;        /* Column size of a */
  69.     INT rstart,rend; /* Row range of submatrix b: rstart:rend */
  70.     INT cstart,cend; /* Column range of submatrix b: cstart:cend */
  71. {
  72.     register DOUBLE *p,*q;    /* Pointers to elements in a and b */
  73.     register INT    i,j;    /* Loop counters */
  74.     INT    mb,nb,step;
  75.  
  76.     /* Size of result array */
  77.     mb = rend - rstart + 1;
  78.     nb = cend - cstart + 1;
  79.  
  80.     /* Copy elements from subsection of a to b */
  81.     step = ma - mb;
  82.     q = b;
  83.     p = a + rstart + cstart*ma;
  84.     for (j=0;j<nb;++j) {
  85.         for (i=0;i<mb;++i) {
  86.             *(q++) = *(p++);
  87.         }
  88.         p += step;
  89.     }
  90. }
  91.  
  92. void conv2(c, a, b, ma, na, mb, nb, plusminus)
  93.     DOUBLE *c;    /* Result matrix (ma+mb-1)-by-(na+nb-1) */
  94.     DOUBLE *a;    /* Larger matrix */
  95.     DOUBLE *b;    /* Smaller matrix */
  96.     INT ma;        /* Row size of a */
  97.     INT na;        /* Column size of a */
  98.     INT mb;        /* Row size of b */
  99.     INT nb;        /* Column size of b */
  100.     INT plusminus;    /* add or subtract from result */
  101. {
  102.     register DOUBLE *p,*q;    /* Pointer to elements in 'a' and 'c' matrices */
  103.     register DOUBLE w;        /* Weight (element of 'b' matrix) */
  104.     INT mc,nc;
  105.     register INT k,l,i,j;
  106.     DOUBLE *r;                /* Pointer to elements in 'b' matrix */
  107.     
  108.     mc = ma+mb-1;
  109.     nc = na+nb-1;
  110.     
  111.     /* Perform convolution */
  112.     r = b;    
  113.     for (j=0; j<nb; ++j) {            /* For each non-zero element in b */
  114.         for (i=0; i<mb; ++i) {
  115.             w = *(r++);                /* Get weight from b matrix */
  116.             if (w != 0.0) {
  117.                 p = c + i + j*mc;    /* Start at first column of a in c. */
  118.                 for (l=0, q=a; l<na; l++) {        /* For each column of a ... */
  119.                     for (k=0; k<ma; k++) {    
  120.                         *(p++) += *(q++) * w * plusminus;    /* multiply by weight and add. */
  121.                     }
  122.                     p += mb - 1;    /* Jump to next column position of a in c */
  123.                 }
  124.             } /* end if */
  125.         } 
  126.     }
  127. }
  128.  
  129. #ifdef __STDC__
  130. void mexFunction(
  131.         INT             nlhs,
  132.         Matrix  *plhs[],
  133.         INT             nrhs,
  134.         Matrix  *prhs[]
  135.         )
  136. #else
  137. mexFunction(nlhs, plhs, nrhs, prhs)
  138. INT nlhs, nrhs;
  139. Matrix *plhs[], *prhs[];
  140. #endif
  141. {
  142.     Matrix    *tmp;
  143.     DOUBLE    *cr, *ci;
  144.     DOUBLE    *ar, *ai, *br, *bi;
  145.     DOUBLE    *p;
  146.     INT        ma,na;
  147.     INT        mb,nb;
  148.     INT        mc,nc;
  149.     INT        cplx;
  150.     INT        switched;
  151.     INT        code;
  152.     char    *shape;
  153.     INT        nshape;
  154.  
  155. #define SAME 0
  156. #define FULL 1
  157. #define VALID 2
  158.  
  159.     /* Check validity of arguments */
  160.  
  161.     if (nrhs < 2) 
  162.         mexErrMsgTxt("CONV2 requires at least two input arguments.");
  163.     if (nlhs > 1) 
  164.         mexErrMsgTxt("CONV2 only has one output argument.");
  165.     if (mxIsSparse(A_IN) || mxIsSparse(B_IN))
  166.         mexErrMsgTxt("CONV2 cannot operate on sparse matrices.");
  167.     if ((nrhs == 3) && (!mxIsString(S_IN) || (mxGetM(S_IN)*mxGetN(S_IN)<1)))
  168.         mexErrMsgTxt("'shape' must be a string.");
  169.     if (nrhs < 3)
  170.         code = FULL;
  171.     else {    /* Get shape parameter */
  172.         nshape = mxGetM(S_IN)*(mxGetN(S_IN)+1);
  173.         shape = (char *) mxCalloc(nshape,sizeof(char));
  174.         if (mxGetString(S_IN,shape,nshape)==1)
  175.             mexErrMsgTxt("Having trouble getting 'shape'.");
  176.         switch (shape[0]) {
  177.             case 's' : code = SAME; break;
  178.             case 'f' : code = FULL; break;
  179.             case 'v' : code = VALID; break;
  180.             default: mexErrMsgTxt("Unknown shape parameter.");
  181.         }
  182.     } /* end if */
  183.  
  184.     /* Get ready to call conv2 */
  185.     cplx = REAL;
  186.     if ((mxGetPi(A_IN) != 0) || (mxGetPi(B_IN) != 0))
  187.         cplx = COMPLEX;
  188.     if ((mxGetM(A_IN) == 0) || (mxGetM(B_IN) == 0)) {  /* Return empty matrix */
  189.         C_OUT = mxCreateFull(0,0,cplx);
  190.     } else {                                /* Compute result */
  191.  
  192.         /* Create temporary matrix to hold full convolution */
  193.         tmp = mxCreateFull(mxGetM(A_IN)+mxGetM(B_IN)-1, 
  194.                                 mxGetN(A_IN)+mxGetN(B_IN)-1, cplx);
  195.  
  196.         /* Assign pointers to various arguments */
  197.  
  198.         cr = mxGetPr(tmp);
  199.         if (cplx)
  200.             ci = mxGetPi(tmp);
  201.  
  202.         if (mxGetM(A_IN) * mxGetN(A_IN) > mxGetM(B_IN) * mxGetN(B_IN)) {
  203.             ar = mxGetPr(A_IN); ma = mxGetM(A_IN); na = mxGetN(A_IN);
  204.             br = mxGetPr(B_IN); mb = mxGetM(B_IN); nb = mxGetN(B_IN);
  205.             if (cplx) {
  206.                 ai = mxGetPi(A_IN); bi = mxGetPi(B_IN);
  207.             }
  208.             switched = 0;
  209.         } else {
  210.             ar = mxGetPr(B_IN); ma = mxGetM(B_IN); na = mxGetN(B_IN);
  211.             br = mxGetPr(A_IN); mb = mxGetM(A_IN); nb = mxGetN(A_IN);
  212.             if (cplx) {
  213.                 ai = mxGetPi(B_IN); bi = mxGetPi(A_IN);
  214.             }
  215.             switched = 1;
  216.         }        
  217.  
  218.         /* Call subroutine to perform actual calculations */
  219.  
  220.         conv2(cr,ar,br,ma,na,mb,nb,PLUS);
  221.         if (cplx) {
  222.             conv2(cr,ai,bi,ma,na,mb,nb,MINUS);
  223.             conv2(ci,ar,bi,ma,na,mb,nb,PLUS);
  224.             conv2(ci,ai,br,ma,na,mb,nb,PLUS);
  225.         }
  226.         
  227.         /* Now extract result and create return argument */
  228.         switch (code) {
  229.             case FULL: 
  230.                 C_OUT = tmp;
  231.                 break;
  232.             case SAME: /* Return center section that is the same size as A */
  233.                 if (switched==1) {
  234.                     mc = mb; nc = nb; mb = ma; nb = na;
  235.                 } else {
  236.                     mc = ma; nc = na;
  237.                 }
  238.                 C_OUT = mxCreateFull(mc, nc, cplx);
  239.                 subMatrix(mxGetPr(C_OUT),mxGetPr(tmp),
  240.                     mxGetM(tmp),mxGetN(tmp),
  241.                     (mb-1)/2,(mb-1)/2+mc-1,(nb-1)/2,(nb-1)/2+nc-1);
  242.                 if (cplx) subMatrix(mxGetPi(C_OUT),mxGetPi(tmp),
  243.                     mxGetM(tmp),mxGetN(tmp),
  244.                     (mb-1)/2,(mb-1)/2+mc-1,(nb-1)/2,(nb-1)/2+nc-1);
  245.                 mxFreeMatrix(tmp);
  246.                 break;
  247.             case VALID:    /* Return center section that is computed without edges. */
  248.                 mc = ma-mb+1; nc = na-nb+1;
  249.                 if ((mc < 0) || (nc < 0)) { /* Catch possible null matrix */
  250.                     C_OUT = mxCreateFull(0,0,cplx);
  251.                     return;
  252.                 }
  253.                 C_OUT = mxCreateFull(mc, nc, cplx);
  254.                 subMatrix(mxGetPr(C_OUT),mxGetPr(tmp),
  255.                     mxGetM(tmp),mxGetN(tmp),mb-1,mb+mc-2,nb-1,nb+nc-2);
  256.                 if (cplx) subMatrix(mxGetPi(C_OUT),mxGetPi(tmp),
  257.                     mxGetM(tmp),mxGetN(tmp),mb-1,mb+mc-2,nb-1,nb+nc-2);
  258.                 mxFreeMatrix(tmp);
  259.                 break;
  260.         }
  261.     }
  262. }
  263.