Use SHA1 and not deprecated MD5 in demos.
[openssl.git] / demos / state_machine / state_machine.c
index b66d0592660d8157e723427790055ef15d334753..fef3f3e3d1fced897d5448399ac69daf7eb8dec5 100644 (file)
 #include <sys/socket.h>
 #include <netinet/in.h>
 
+/* die_unless is intended to work like assert, except that it happens
+   always, even if NDEBUG is defined. Use assert as a stopgap. */
+
+#define die_unless(x)  assert(x)
+
 typedef struct
     {
     SSL_CTX *pCtx;
@@ -111,24 +116,22 @@ SSLStateMachine *SSLStateMachine_new(const char *szCertificateFile,
     SSLStateMachine *pMachine=malloc(sizeof *pMachine);
     int n;
 
-    assert(pMachine);
+    die_unless(pMachine);
 
     pMachine->pCtx=SSL_CTX_new(SSLv23_server_method());
-    assert(pMachine->pCtx);
+    die_unless(pMachine->pCtx);
 
     n=SSL_CTX_use_certificate_file(pMachine->pCtx,szCertificateFile,
                                   SSL_FILETYPE_PEM);
-    assert(n > 0);
+    die_unless(n > 0);
 
     n=SSL_CTX_use_PrivateKey_file(pMachine->pCtx,szKeyFile,SSL_FILETYPE_PEM);
-    assert(n > 0);
+    die_unless(n > 0);
 
     pMachine->pSSL=SSL_new(pMachine->pCtx);
-    assert(pMachine->pSSL);
+    die_unless(pMachine->pSSL);
 
     pMachine->pbioRead=BIO_new(BIO_s_mem());
-    /* Set EOF to return 0 (-1 is the default) */
-    BIO_ctrl(pMachine->pbioRead,BIO_C_SET_BUF_MEM_EOF_RETURN,0,NULL);
 
     pMachine->pbioWrite=BIO_new(BIO_s_mem());
 
@@ -160,15 +163,39 @@ int SSLStateMachine_read_extract(SSLStateMachine *pMachine,
        {
        fprintf(stderr,"Doing SSL_accept\n");
        n=SSL_accept(pMachine->pSSL);
-       if(n < 0)
-           SSLStateMachine_print_error(pMachine,"SSL_accept failed");
        if(n == 0)
            fprintf(stderr,"SSL_accept returned zero\n");
-       assert(n >= 0);
+       if(n < 0)
+           {
+           int err;
+
+           if((err=SSL_get_error(pMachine->pSSL,n)) == SSL_ERROR_WANT_READ)
+               {
+               fprintf(stderr,"SSL_accept wants more data\n");
+               return 0;
+               }
+
+           SSLStateMachine_print_error(pMachine,"SSL_accept error");
+           exit(7);
+           }
        return 0;
        }
 
     n=SSL_read(pMachine->pSSL,aucBuf,nBuf);
+    if(n < 0)
+       {
+       int err=SSL_get_error(pMachine->pSSL,n);
+
+       if(err == SSL_ERROR_WANT_READ)
+           {
+           fprintf(stderr,"SSL_read wants more data\n");
+           return 0;
+           }
+
+       SSLStateMachine_print_error(pMachine,"SSL_read error");
+       exit(8);
+       }
+
     fprintf(stderr,"%d bytes of decrypted data read from state machine\n",n);
     return n;
     }
@@ -258,13 +285,15 @@ int OpenSocket(int nPort)
     return nFD;
     }
 
-void main(int argc,char **argv)
+int main(int argc,char **argv)
     {
     SSLStateMachine *pMachine;
     int nPort;
     int nFD;
     const char *szCertificateFile;
     const char *szKeyFile;
+    char rbuf[1];
+    int nrbuf=0;
 
     if(argc != 4)
        {
@@ -297,6 +326,14 @@ void main(int argc,char **argv)
        /* Select socket for input */
        FD_SET(nFD,&rfds);
 
+       /* check whether there's decrypted data */
+       if(!nrbuf)
+           nrbuf=SSLStateMachine_read_extract(pMachine,rbuf,1);
+
+       /* if there's decrypted data, check whether we can write it */
+       if(nrbuf)
+           FD_SET(1,&wfds);
+
        /* Select socket for output */
        if(SSLStateMachine_write_can_extract(pMachine))
            FD_SET(nFD,&wfds);
@@ -322,21 +359,29 @@ void main(int argc,char **argv)
            SSLStateMachine_read_inject(pMachine,buf,n);
            }
 
-       /* FIXME: we should only extract if stdout is ready */
-       n=SSLStateMachine_read_extract(pMachine,buf,n);
-       if(n < 0)
-           {
-           SSLStateMachine_print_error(pMachine,"read extract failed");
-           break;
-           }
-       assert(n >= 0);
-       if(n > 0)
+       /* stdout is ready for output (and hence we have some to send it) */
+       if(FD_ISSET(1,&wfds))
            {
-           int w;
+           assert(nrbuf == 1);
+           buf[0]=rbuf[0];
+           nrbuf=0;
 
-           w=write(1,buf,n);
-           /* FIXME: we should push back any unwritten data */
-           assert(w == n);
+           n=SSLStateMachine_read_extract(pMachine,buf+1,sizeof buf-1);
+           if(n < 0)
+               {
+               SSLStateMachine_print_error(pMachine,"read extract failed");
+               break;
+               }
+           assert(n >= 0);
+           ++n;
+           if(n > 0) /* FIXME: has to be true now */
+               {
+               int w;
+               
+               w=write(1,buf,n);
+               /* FIXME: we should push back any unwritten data */
+               assert(w == n);
+               }
            }
 
        /* Socket is ready for output (and therefore we have output to send) */
@@ -366,4 +411,6 @@ void main(int argc,char **argv)
            SSLStateMachine_write_inject(pMachine,buf,n);
            }
        }
+    /* not reached */
+    return 0;
     }