/***************************************
  WWWOFFLE - World Wide Web Offline Explorer - Version 2.9j.

  Functions for file input and output using gnutls.
  ******************/ /******************
  Written by Andrew M. Bishop

  This file Copyright 2005-2016 Andrew M. Bishop
  It may be distributed under the GNU Public License, version 2, or
  any higher version.  See section COPYING of the GNU Public license
  for conditions under which this file may be redistributed.
  ***************************************/


#include "autoconfig.h"

#include <stdlib.h>
#include <stdio.h>
#include <string.h>

#if TIME_WITH_SYS_TIME
# include <sys/time.h>
# include <time.h>
#else
# if HAVE_SYS_TIME_H
#  include <sys/time.h>
# else
#  include <time.h>
# endif
#endif

#include <errno.h>

#include <setjmp.h>
#include <signal.h>

#if USE_GNUTLS
#include <gnutls/gnutls.h>
#include <gnutls/x509.h>
#endif

#include "io.h"
#include "iopriv.h"
#include "errors.h"


#if USE_GNUTLS

#include "certificates.h"


/*+ A longjump context for write timeouts; it is filled in by setjmp() in
    push_with_timeout() and jumped to by the sigalarm() signal handler. +*/
static jmp_buf write_jmp_env;


/* Local functions */

/* The transport callbacks that io_init_gnutls() installs on the gnutls session.
   NOTE(review): both return int but are cast to gnutls_pull_func / gnutls_push_func
   when installed - confirm that this matches the return type gnutls expects. */
static int pull_with_timeout(gnutls_transport_ptr_t ptr,       void* data, size_t size);
static int push_with_timeout(gnutls_transport_ptr_t ptr, const void* data, size_t size);

/* The callback registered for server sessions to select credentials during the handshake. */
static int handshake_sni_callback(gnutls_session_t session);

/* Attach context->cred to the session; shared by io_init_gnutls() and the SNI callback. */
static int set_credentials(io_gnutls *context);

/* The SIGALRM handler used to abort a write that takes too long. */
static void sigalarm(int signum);

/* Send a whole buffer through gnutls_record_send(), retrying partial sends. */
static ssize_t write_all(gnutls_session_t session,const char *data,size_t n);

/* Record a gnutls error in io_strerror and flag it through errno. */
static void set_gnutls_error(int err,gnutls_session_t session);


/*++++++++++++++++++++++++++++++++++++++
  Initialise the gnutls context information and perform the TLS handshake.

  io_gnutls *io_init_gnutls Returns a new gnutls io context or NULL if anything failed.

  int fd The file descriptor for the session.

  const char *host For a client the name of the server being connected to (sent as the SNI name
                   and used when storing the real certificate); for a server the name that is
                   used when the client does not supply a usable SNI name.

  int type A flag set to 0 for client connection, 1 for built-in server or 2 for a fake server.

  unsigned timeout_r The read timeout or 0 for none.

  unsigned timeout_w The write timeout or 0 for none.

  On every failure path the session, any credentials already attached to the context and the
  context itself are released before returning.
  ++++++++++++++++++++++++++++++++++++++*/

io_gnutls *io_init_gnutls(int fd,const char *host,int type,unsigned timeout_r,unsigned timeout_w)
{
 int saved_errno;
 io_gnutls *context=(io_gnutls*)calloc(1,sizeof(io_gnutls));

 if(!context)
   {
    PrintMessage(Warning,"GNUTLS Failed to allocate memory for the session context.");

    errno=ENOMEM;
    return(NULL);
   }

 /* Initialise the gnutls session. */

 io_errno=gnutls_init(&context->session,type?GNUTLS_SERVER:GNUTLS_CLIENT);

 if(io_errno!=GNUTLS_E_SUCCESS)
   {
    set_gnutls_error(io_errno,context->session);

    PrintMessage(Warning,"GNUTLS Failed to initialise session [%s].",io_strerror);

    /* The session was never created so there is nothing to deinit here. */

    free(context);
    return(NULL);
   }

 io_errno=gnutls_set_default_priority(context->session);

 if(io_errno!=GNUTLS_E_SUCCESS)
   {
    set_gnutls_error(io_errno,context->session);

    PrintMessage(Warning,"GNUTLS Failed to set session priority [%s].",io_strerror);

    goto failed;
   }

 /* Set the server credentials (in callback for server mode, now for client) */

 if(type)
   {
    context->type=type;
    context->host=host;

    gnutls_handshake_set_post_client_hello_function(context->session,
                                                    (gnutls_handshake_post_client_hello_func)handshake_sni_callback);
   }
 else /* if(type==0) */
   {
    /* Without a host name there is nothing to send as SNI (and strlen(NULL) would crash). */

    if(host)
       gnutls_server_name_set(context->session,GNUTLS_NAME_DNS,host,strlen(host));

    context->cred=GetClientCredentials();

    if(set_credentials(context))
       goto failed;             /* set_credentials() has already reported the problem. */
   }

 /* Store the file descriptor and timeout and set the push and pull functions. */

 context->fd=fd;

 context->r_timeout=timeout_r;
 context->w_timeout=timeout_w;

 gnutls_transport_set_ptr(context->session,context);

 gnutls_transport_set_pull_function(context->session,(gnutls_pull_func)pull_with_timeout);
 gnutls_transport_set_push_function(context->session,(gnutls_push_func)push_with_timeout);

 /* Handshake the session on the socket, repeating for as long as the error is not fatal. */

 do
   {
    io_errno=gnutls_handshake(context->session);
   }
 while(io_errno!=GNUTLS_E_SUCCESS && !gnutls_error_is_fatal(io_errno));

 if(io_errno!=GNUTLS_E_SUCCESS)
   {
    set_gnutls_error(io_errno,context->session);
    gnutls_bye(context->session,GNUTLS_SHUT_WR);

    PrintMessage(Warning,"GNUTLS handshake has failed [%s].",io_strerror);

    goto failed;
   }

 /* Save the server credentials */

 if(type==0 && host)
    PutRealCertificate(context->session,host);

 return(context);

 /* Common cleanup for failures after the session has been created. */

 failed:

 saved_errno=errno;             /* keep the error indication through the cleanup calls */

 gnutls_deinit(context->session);

 if(context->cred)              /* released in the same way as io_finish_gnutls() does */
    FreeCredentials(context->cred);

 free(context);

 errno=saved_errno;

 return(NULL);
}


/*++++++++++++++++++++++++++++++++++++++
  A function to be called by gnutls to read data from the socket.

  int pull_with_timeout Return the number of bytes read, 0 if the read timeout expired before
                        any data was available or a negative number for an error.

  gnutls_transport_ptr_t ptr The gnutls per-session transport pointer (the io_gnutls context).

  void* data The buffer to fill with data read from the socket.

  size_t size The amount of data to read.

  This is necessary because the SNI callback only passes the gnutls session but we
  need the WWWOFFLE IO context so we have to put that into the gnutls_transport_ptr.
  Therefore the timeout function is moved into here so that it applies to all gnutls
  data transfers.
  ++++++++++++++++++++++++++++++++++++++*/

static int pull_with_timeout(gnutls_transport_ptr_t ptr, void* data, size_t size)
{
 io_gnutls *io=(io_gnutls*)ptr;

 /* Only wait with select() when a read timeout has been requested. */

 if(io->r_timeout)
   {
    int ready;

    do
      {
       fd_set rfds;
       struct timeval wait;

       /* The set and the timeout are rebuilt for every attempt. */

       FD_ZERO(&rfds);
       FD_SET(io->fd,&rfds);

       wait.tv_sec=io->r_timeout;
       wait.tv_usec=0;

       ready=select(io->fd+1,&rfds,NULL,NULL,&wait);

       if(ready==0)             /* timed out with nothing to read */
          return(0);

       if(ready<0 && errno!=EINTR)
          return(-1);
      }
    while(ready<=0);            /* only an interrupted select() gets here; try again */
   }

 return(read(io->fd,data,size));
}


/*++++++++++++++++++++++++++++++++++++++
  The signal handler for the alarm signal, used to abandon a socket write that has timed out.

  int signum The signal number (not used).

  Control does not return from here; it jumps to the setjmp() point stored in write_jmp_env
  with the value 1.
  ++++++++++++++++++++++++++++++++++++++*/

static void sigalarm(/*@unused@*/ int signum)
{
 (void)signum;

 longjmp(write_jmp_env,1);
}


/*++++++++++++++++++++++++++++++++++++++
  A function to be called by gnutls to write data to the socket.

  int push_with_timeout Return the number of bytes written or a negative number for an error
                        (with errno set to ETIMEDOUT if the write timeout expired).

  gnutls_transport_ptr_t ptr The gnutls per-session transport pointer (the io_gnutls context).

  const void* data The buffer of data to write to the socket.

  size_t size The amount of data to write.

  This is necessary because the SNI callback only passes the gnutls session but we
  need the WWWOFFLE IO context so we have to put that into the gnutls_transport_ptr.
  Therefore the timeout function is moved into here so that it applies to all gnutls
  data transfers.

  The timeout is implemented with alarm() and a SIGALRM handler that longjmp()s back here.
  ++++++++++++++++++++++++++++++++++++++*/

static int push_with_timeout(gnutls_transport_ptr_t ptr, const void* data, size_t size)
{
 io_gnutls *context=(io_gnutls*)ptr;
 struct sigaction action;
 int n,saved_errno;

 /* Install the alarm handler if a timeout is wanted, falling back to no timeout on failure. */

 if(context->w_timeout)
   {
    action.sa_handler = sigalarm;
    sigemptyset(&action.sa_mask);

    /* SA_NODEFER: the handler leaves by longjmp() which is not guaranteed to restore the
       signal mask; without this flag SIGALRM could stay blocked after the first timeout
       and later timeouts would never fire. */

    action.sa_flags = SA_NODEFER;

    if(sigaction(SIGALRM, &action, NULL) != 0)
      {
       PrintMessage(Warning, "Failed to set SIGALRM; cancelling timeout for writing.");
       context->w_timeout=0;
      }
   }

 if(!context->w_timeout)
    return(write(context->fd,data,size));

 /* Save the jump context before arming the alarm so that the handler can never
    jump to an uninitialised jmp_buf. */

 if(setjmp(write_jmp_env))
   {
    n=-1;
    saved_errno=ETIMEDOUT;
   }
 else
   {
    alarm(context->w_timeout);

    n=write(context->fd,data,size);
    saved_errno=errno;
   }

 /* Cancel the alarm and ignore SIGALRM again. */

 alarm(0);
 action.sa_handler = SIG_IGN;
 sigemptyset(&action.sa_mask);
 action.sa_flags = 0;
 if(sigaction(SIGALRM, &action, NULL) != 0)
    PrintMessage(Warning, "Failed to clear SIGALRM.");

 /* The cleanup above must not hide the error that the caller needs to see. */

 errno=saved_errno;

 return(n);
}


/*++++++++++++++++++++++++++++++++++++++
  A callback to handle the TLS SNI extension.

  int handshake_sni_callback Returns 0 if OK, something else otherwise (the result of set_credentials).

  gnutls_session_t session The GNUTLS session information.

  The server name requested by the client is used to choose the credentials; if no usable
  DNS name was supplied (or it could not be retrieved) the host name stored in the context
  is used instead.
  ++++++++++++++++++++++++++++++++++++++*/

static int handshake_sni_callback(gnutls_session_t session)
{
 int err;
 size_t sni_len=40;             /* initial guess, enlarged below if gnutls asks for more */
 unsigned int sni_type=0;
 char *sni_name;
 const char *name;
 io_gnutls *context;

 context=gnutls_transport_get_ptr(session);

 /* Work out the server requested (using SNI) */

 sni_name=(char*)malloc(sni_len);

 if(!sni_name)
    err=GNUTLS_E_MEMORY_ERROR;
 else
   {
    err=gnutls_server_name_get(context->session,sni_name,&sni_len,&sni_type,0);

    if(err==GNUTLS_E_SHORT_MEMORY_BUFFER)
      {
       /* sni_len now holds the size that is needed; keep the old pointer until
          the reallocation is known to have worked. */

       char *bigger=(char*)realloc((void*)sni_name,sni_len);

       if(bigger)
         {
          sni_name=bigger;

          err=gnutls_server_name_get(context->session,sni_name,&sni_len,&sni_type,0);
         }
       else
          err=GNUTLS_E_MEMORY_ERROR;
      }
   }

 if(err==GNUTLS_E_SUCCESS && sni_type==GNUTLS_NAME_DNS)
    PrintMessage(Inform,"GNUTLS SNI server name was '%s'.",sni_name);
 else
   {
    if(err==GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE)
       PrintMessage(Inform,"GNUTLS No SNI server name found, using '%s'.",context->host);
    else if(err==GNUTLS_E_SUCCESS)
       PrintMessage(Warning,"GNUTLS Requested SNI server name was not a DNS name [type=%d], using '%s'.",(int)sni_type,context->host);
    else
       PrintMessage(Warning,"GNUTLS Request for SNI server name returned an error [%s], using '%s'.",gnutls_strerror(err),context->host);

    free(sni_name);
    sni_name=NULL;
   }

 /* Set the server credentials */

 name=sni_name?sni_name:context->host;

 if(context->type==1)
    context->cred=GetServerCredentials(name);
 else /* if(context->type==2) */
    context->cred=GetFakeCredentials(name);

 free(sni_name);

 return(set_credentials(context));
}


/*++++++++++++++++++++++++++++++++++++++
  Set the GNUTLS credentials.

  int set_credentials Return 0 if OK or -1 otherwise (with io_strerror describing the problem).

  io_gnutls *context The gnutls context information; context->cred must already have been
                     filled in by the caller (a NULL value is reported as an error).

  This is only in a separate function because it needs to be called both
  from the SNI callback function and from the main io_init_gnutls function.
  ++++++++++++++++++++++++++++++++++++++*/

static int set_credentials(io_gnutls *context)
{
 static const char no_cred_msg[]="IO(gnutls): Failed to get credentials";

 if(!context->cred)
   {
    free(io_strerror);
    io_strerror=(char*)malloc(sizeof(no_cred_msg));

    if(io_strerror)
       memcpy(io_strerror,no_cred_msg,sizeof(no_cred_msg));

    /* Print from the constant so that a failed allocation cannot pass NULL to '%s'. */

    PrintMessage(Warning,"GNUTLS Failed to get server credentials [%s].",no_cred_msg);

    /* Flag the error through errno in the same way that set_gnutls_error() does. */

    errno=ERRNO_USE_IO_ERRNO;

    return(-1);
   }

 io_errno=gnutls_credentials_set(context->session,GNUTLS_CRD_CERTIFICATE,context->cred);

 if(io_errno!=GNUTLS_E_SUCCESS)
   {
    set_gnutls_error(io_errno,context->session);

    PrintMessage(Warning,"GNUTLS Failed to set session credentials [%s].",io_strerror);

    return(-1);
   }

 return(0);
}


/*++++++++++++++++++++++++++++++++++++++
  Finalise the gnutls data stream and release everything owned by the context.

  int io_finish_gnutls Returns 0 on completion (the code below has no failure path).

  io_gnutls *context The gnutls context information; it is freed and must not be used afterwards.
  ++++++++++++++++++++++++++++++++++++++*/

int io_finish_gnutls(io_gnutls *context)
{
 gnutls_session_t tls=context->session;

 /* Close the writing side of the TLS session, then destroy the session. */

 gnutls_bye(tls,GNUTLS_SHUT_WR);
 gnutls_deinit(tls);

 /* Release the credentials and finally the context itself. */

 FreeCredentials(context->cred);
 free(context);

 return(0);
}


/*++++++++++++++++++++++++++++++++++++++
  Read data from a gnutls session and buffer it with a timeout.

  ssize_t io_gnutls_read_with_timeout Returns the number of bytes read, 0 at the end of the
                                      data or a negative gnutls error code.

  io_gnutls *context The gnutls context information.

  io_buffer *out The IO buffer to output the data; new data is appended after out->length.

  unsigned timeout The maximum time to wait for data to be read (or 0 for no timeout).
  ++++++++++++++++++++++++++++++++++++++*/

ssize_t io_gnutls_read_with_timeout(io_gnutls *context,io_buffer *out,unsigned timeout)
{
 int nread;

 /* The timeout is applied by the pull function, which reads it from the context. */

 context->r_timeout=timeout;

 /* Receive one record, repeating while gnutls reports a retryable condition. */

 for(;;)
   {
    nread=gnutls_record_recv(context->session,out->data+out->length,out->size-out->length);

    if(nread!=GNUTLS_E_INTERRUPTED && nread!=GNUTLS_E_AGAIN)
       break;
   }

 /* Data was received. */

 if(nread>0)
   {
    out->length+=nread;
    return(nread);
   }

 /* A request to renegotiate is refused with a warning alert (and then reported as an error). */

 if(nread==GNUTLS_E_REHANDSHAKE)
    gnutls_alert_send(context->session,GNUTLS_AL_WARNING,GNUTLS_A_NO_RENEGOTIATION);

 /* These errors seem to happen at the end of the data so they are treated as end of file. */

 if(nread==0 ||
    nread==GNUTLS_E_UNEXPECTED_PACKET_LENGTH ||
    nread==GNUTLS_E_PREMATURE_TERMINATION ||
    nread==GNUTLS_E_INVALID_SESSION)
    return(0);

 /* Anything else is a real error. */

 set_gnutls_error(nread,context->session);

 return(nread);
}


/*++++++++++++++++++++++++++++++++++++++
  Write some data to a gnutls session from a buffer with a timeout.

  ssize_t io_gnutls_write_with_timeout Returns the number of bytes written or negative on error.

  io_gnutls *context The gnutls context information.

  io_buffer *in The IO buffer with the input data; its length is reset to zero on return
                whether or not the write succeeded.

  unsigned timeout The maximum time to wait for data to be written (or 0 for no timeout).

  Buffers holding more than 4*IO_BUFFER_SIZE bytes are sent in pieces of at most
  IO_BUFFER_SIZE bytes; the return value is still the total number of bytes written.
  ++++++++++++++++++++++++++++++++++++++*/

ssize_t io_gnutls_write_with_timeout(io_gnutls *context,io_buffer *in,unsigned timeout)
{
 size_t total=in->length;
 ssize_t n=0;

 if(total==0)
    return(0);

 /* The timeout is applied by the push function, which reads it from the context. */

 context->w_timeout=timeout;

 if(total>(4*IO_BUFFER_SIZE))
   {
    size_t offset;

    for(offset=0;offset<total;offset+=IO_BUFFER_SIZE)
      {
       size_t chunk=total-offset;

       if(chunk>IO_BUFFER_SIZE)
          chunk=IO_BUFFER_SIZE;

       n=write_all(context->session,in->data+offset,chunk);

       if(n<0)
          break;
      }
   }
 else
    n=write_all(context->session,in->data,total);

 /* The buffer is considered to be consumed even if the write failed. */

 in->length=0;

 if(n<0)
    return(n);

 return((ssize_t)total);
}


/*++++++++++++++++++++++++++++++++++++++
  A function to write all of a buffer of data to a gnutls session.

  ssize_t write_all Returns the number of bytes written (always n on success) or the negative
                    gnutls error code of the send that failed.

  gnutls_session_t session The gnutls session.

  const char *data The data buffer to write.

  size_t n The number of bytes to write.

  gnutls_record_send() may accept less than was asked for so it is called again with the
  remainder until everything has gone; interrupted sends are simply repeated.
  ++++++++++++++++++++++++++++++++++++++*/

static ssize_t write_all(gnutls_session_t session,const char *data,size_t n)
{
 size_t sent=0;

 do
   {
    ssize_t m;

    do
      {
       m=gnutls_record_send(session,data+sent,n-sent);
      }
    while(m==GNUTLS_E_INTERRUPTED || m==GNUTLS_E_AGAIN);

    /* Keep the error code in a signed type all the way to the caller. */

    if(m<0)
       return(m);

    sent+=(size_t)m;
   }
 while(sent<n);

 return((ssize_t)sent);
}


/*++++++++++++++++++++++++++++++++++++++
  Set the error status when there is a gnutls error.

  int err The error number.

  gnutls_session_t session The session information if one is active.

  The text "IO(gnutls): <type> <message>" is stored in a newly allocated io_strerror
  (replacing any previous one) and errno is set to ERRNO_USE_IO_ERRNO.  If the memory for
  the message cannot be allocated io_strerror is left as NULL and errno is set to ENOMEM.
  ++++++++++++++++++++++++++++++++++++++*/

static void set_gnutls_error(int err,gnutls_session_t session)
{
 const char *type,*msg;
 size_t size;
 int is_alert=(err==GNUTLS_E_WARNING_ALERT_RECEIVED || err==GNUTLS_E_FATAL_ALERT_RECEIVED);

 if(err==GNUTLS_E_WARNING_ALERT_RECEIVED)
    type="Warning Alert:";
 else if(err==GNUTLS_E_FATAL_ALERT_RECEIVED)
    type="Fatal Alert:";
 else
    type="Error:";

 /* For alerts the useful text is the name of the alert, which needs the session. */

 if(!is_alert)
    msg=gnutls_strerror(err);
 else if(session)
    msg=gnutls_alert_get_name(gnutls_alert_get(session));
 else
    msg="No session info";

 /* Guard against gnutls having no name for the value so that strlen() is safe. */

 if(!msg)
    msg="Unknown";

 free(io_strerror);

 size=16+strlen(type)+strlen(msg)+1;
 io_strerror=(char*)malloc(size);

 if(!io_strerror)
   {
    errno=ENOMEM;
    return;
   }

 snprintf(io_strerror,size,"IO(gnutls): %s %s",type,msg);

 /* Set last so that none of the calls above can overwrite it. */

 errno=ERRNO_USE_IO_ERRNO;
}

#endif /* USE_GNUTLS */
