#---------------------------------------------------------------------
# Copyright (C) 2015  Seguesoft  Inc.
#                                                                             
# Redistribution of this software, in whole or in part, is prohibited         
# without the express written permission of Seguesoft. 
# Modified based on ncclient

import os, time
import socket
import socks
import getpass
from binascii import hexlify
from io import BytesIO
from select import select
import traceback

import ssl

from .errors import AuthenticationError, SessionCloseError, TLSError, TLSUnknownHostError
from .session import Session

import logging
# Module-level logger, registered under the ncclient transport namespace.
logger = logging.getLogger("ncclient.transport.tls")

# Maximum number of bytes requested from the TLS socket per recv() call.
BUF_SIZE = 4096
# End-of-message marker used when messages are delimiter-framed.
MSG_DELIM = "]]>]]>"
# select() timeout, in seconds, used by the I/O loop.
TICK = 0.1
# End-of-chunks marker used when messages are chunk-framed.
MSG_DELIM_2 = "\n##\n"

# Characters allowed as the first digit of a chunk size (no leading zero) ...
DIGIT1 = list("123456789")
# ... and as any subsequent digit.
DIGIT = ["0"] + DIGIT1

def default_unknown_host_cb(hostname):
    """Default policy for a host whose identity could not be confirmed.

    :param hostname: the host name that was being connected to (ignored).
    :return: always ``False``, i.e. never accept the host.
    """
    accept_host = False
    return accept_host

class TLSSession(Session):
    "Implements a :rfc:`5539` NETCONF session over TLS."

    def __init__(self, capabilities, timeout=60):
        """Create an unconnected TLS session.

        :param capabilities: forwarded unchanged to ``Session.__init__``.
        :param timeout: forwarded to ``Session.__init__`` as ``timeout``.
        """
        Session.__init__(self, capabilities, timeout=timeout)

        # Peer name and local source address; both are overwritten by connect().
        self._host = self._hostsrc = "localhost"
        self._ciphers = None

        # Connection handles: nothing is open yet.
        self._transport = self._channel = None
        self._connected = False

        # Bytes received from the peer that have not yet been dispatched.
        self._buffer = BytesIO()

        # Scanner state retained between calls to _parse().
        self._parsing_state = self._parsing_pos = 0
        
    def _parse(self):
        """Dispatch every complete MSG_DELIM-terminated message in the buffer.

        For each delimiter found in ``self._buffer``, the bytes in front of it
        are stripped, decoded, logged and handed to ``self._dispatch_message``.
        The buffer is then replaced by a fresh ``BytesIO`` that holds only the
        bytes following the delimiter.

        State kept across calls:

        * ``self._parsing_pos`` - how much of the buffer has already been
          scanned.  The next scan resumes ``len(delimiter) - 1`` bytes before
          that point so that a delimiter split across two reads is still found.
        * ``self._parsing_state`` - always reset to 0; a partial delimiter
          match is recovered by the back-up described above instead.

        On return the buffer's position is at its end, so the caller's next
        ``write()`` appends to the unparsed data.

        The delimiter is searched for as raw bytes.  Decoding one byte at a
        time fails on multi-byte UTF-8 characters, and a byte-by-byte matcher
        that resets on mismatch misses a delimiter directly preceded by ``]``.
        """
        delim = MSG_DELIM.encode()
        n = len(delim)
        buf = self._buffer

        # Only the not-yet-scanned tail (plus n-1 bytes of overlap) is examined.
        scan_from = max(0, self._parsing_pos - (n - 1))
        buf.seek(scan_from)
        tail = buf.read()

        while True:
            hit = tail.find(delim)
            if hit < 0:
                break

            logger.debug('parsed new message')
            # A message always starts at offset 0 because the buffer is
            # rebuilt after every dispatched message.
            buf.seek(0)
            newmsg = buf.read(scan_from + hit).strip()
            text = newmsg.decode()
            logger.info(text)
            self._dispatch_message(text)

            # Keep only what follows the delimiter, in a fresh buffer.
            tail = tail[hit + n:]
            buf.close()
            buf = BytesIO()
            buf.write(tail)
            scan_from = 0

        # Leave the write position at the end for the next incoming data.
        buf.seek(0, os.SEEK_END)
        self._buffer = buf
        self._parsing_state = 0
        self._parsing_pos = buf.tell()
    
    def _parse2(self):
        r"""Dispatch every complete chunk-framed message in the buffer.

        The whole buffer is re-scanned from offset 0 on each call.  A message
        is a sequence of chunks, each introduced by ``\n#<size>\n`` and
        followed by ``<size>`` bytes of data, terminated by ``\n##\n``::

            \n#4\n<rpc\n#16\n message-id="0"\n\n##\n

        Chunk bodies are collected through ``self._readchunk``.  When the
        end-of-chunks marker is reached the collected bytes are decoded and
        passed to ``self._dispatch_message``, and the buffer is replaced by a
        fresh ``BytesIO`` holding only the bytes that follow the marker.  An
        incomplete message is left in the buffer untouched.

        Framing bytes are compared as raw bytes so that a stray non-ASCII byte
        cannot raise ``UnicodeDecodeError`` in the scanner.
        """
        rpcmsg = b""
        buf = self._buffer

        buf.seek(0)
        while True:
            x = buf.read(1)
            if not x:  # EOF
                break
            if x != b"\n":
                continue

            x = buf.read(1)
            if not x:  # EOF
                break
            if x != b"#":
                # Not a frame header: re-examine this byte, it may itself be
                # the "\n" that starts one.
                buf.seek(buf.tell() - 1, os.SEEK_SET)
                continue

            # "\n#" seen: either a chunk header or the end-of-chunks marker.
            x = buf.read(1)
            if not x:  # EOF
                break
            if x in b"123456789":
                # First digit of a chunk size; the helper reads the rest of
                # the header and the chunk body.
                rpcmsg = self._readchunk(x, buf, rpcmsg)
            elif x == b"#":
                # Possible end-of-chunks marker "\n##\n".
                x = buf.read(1)
                if not x:  # EOF
                    break
                if x == b"\n":
                    self._dispatch_message(rpcmsg.decode())
                    # Start the next message with an empty accumulator;
                    # otherwise it would be delivered with this one prepended.
                    rpcmsg = b""
                    # Preserve the remaining data in a new buffer.
                    rest = buf.read()
                    buf.close()
                    buf = BytesIO()
                    buf.write(rest)
                    buf.seek(0)
                    # Publish the new buffer immediately so the session never
                    # holds a closed one if a later step raises.
                    self._buffer = buf
            else:
                # "\n#" followed by something else: step back so the last two
                # bytes are scanned again.
                buf.seek(buf.tell() - 2, os.SEEK_SET)

        # The read position is now at EOF, so the next write() appends.
        self._buffer = buf

    def _readchunk(self, x, buf, rpcmsg):
        r"""Read one chunk from ``buf`` and append its data to ``rpcmsg``.

        Called with ``buf`` positioned just after the first size digit of a
        ``\n#<size>\n`` chunk header.

        :param x: the first size digit (a 1-byte ``bytes`` in ``1``-``9``).
        :param buf: the ``BytesIO`` being parsed; its position is advanced.
        :param rpcmsg: message bytes accumulated from earlier chunks.
        :return: ``rpcmsg`` with this chunk's data appended, or ``rpcmsg``
            unchanged when the header or the chunk body is not yet completely
            available in ``buf``.
        :raises SessionCloseError: after calling ``self.close()``, when the
            size exceeds 4294967295 or when the chunk data is not followed by
            ``\n#`` (the start of the next chunk or of the end-of-chunks
            marker).

        On success the position is rewound onto the ``\n#`` that follows the
        chunk so the caller can parse it.
        """
        # Collect the remaining digits of the chunk size.
        size = x
        while True:
            x = buf.read(1)
            if not x:
                # Header incomplete; wait for more data.
                return rpcmsg
            if x == b"\n":
                break  # end of the size line
            if x in b"0123456789":
                size += x
            # NOTE(review): any other byte is skipped, as before; strict
            # framing would treat it as an error - confirm before tightening.

        intsize = int(size)
        logger.debug("Parsed out new chcunk size %d ", intsize)

        if intsize > 4294967295:
            # Invalid size: terminate the session.
            logger.debug('chunk size invalid > 4294967295')
            self.close()
            raise SessionCloseError("invalid chunksize %s" % str(intsize))

        chunk = buf.read(intsize)
        if len(chunk) != intsize:
            # Not enough data yet; more must be received first.
            logger.debug("NOT Enough DATA!!! read in %d", len(chunk))
            return rpcmsg
        rpcmsg += chunk

        # The chunk must be followed by "\n#".
        x = buf.read(1)
        if not x:
            return rpcmsg  # trailer not received yet
        if x == b"\n":
            x = buf.read(1)
            if not x:
                return rpcmsg  # trailer not received yet
            if x == b"#":
                # Rewind onto "\n#" so the caller sees the next header or the
                # end-of-chunks marker.
                buf.seek(buf.tell() - 2, os.SEEK_SET)
                return rpcmsg

        # Anything else means the declared size and the data disagree.
        # Previously "\n" followed by a non-"#" byte fell back into the size
        # loop with the stale size and parsed garbage as another chunk.
        logger.debug('The length of the chunk does not match the chunk size value! Data remains...')
        self.close()
        raise SessionCloseError("The length of the chunk does not match the chunk size value! Data remains...")
            
        
    def close(self):
        if self._transport is not None:
            self._transport.close()
        self._connected = False
            

    # REMEMBER to update transport.rst if sig. changes, since it is hardcoded there
    def connect(self, host, port=6513, timeout=None,
                unknown_host_cb=default_unknown_host_cb,
                client_sock =None,
                client_cert=None, client_key=None, trusted_certs=None,
                # Using custom tls_context instead of the default 
                # so that you can provide all kinds of constraints 
                # Python SSLContext allows
                tls_context=None,
                # instead of exposing a whole new context, we also 
                # provide two frequently used parameters: version / ciphers
                tls_minimum_version=None, 
                tls_maximum_version=None, 
                tls_ciphers=None,
                socks_proxy={"server":'', "port":1080, "type":"SOCKS5", "user": '', "password":''}):
        """Connect via TLS and initialize the NETCONF session.

        :param host: peer host name or address.
        :param port: peer TCP port.
        :param timeout: socket timeout used while connecting; ``self.timeout``
            is set to this value, or to 30 when it is ``None``.
        :param unknown_host_cb: called with *host* when the peer certificate
            does not match it; return ``True`` to accept the peer anyway.
        :param client_sock: an already connected socket (e.g. call-home).
            When given, no new connection is opened.
        :param client_cert: client certificate file; loaded only together
            with *client_key*.
        :param client_key: private key file for *client_cert*.
        :param trusted_certs: CA file passed to ``ssl.create_default_context``
            when *tls_context* is not given.
        :param tls_context: a caller-built ``ssl.SSLContext`` used instead of
            the default one.  ``check_hostname`` and ``verify_mode`` are
            overridden on it, as on the default context.
        :param tls_minimum_version: name of an ``ssl.TLSVersion`` member, or
            ``"Auto-negotiation"`` / empty to leave the context default.
        :param tls_maximum_version: same as *tls_minimum_version*, for the
            upper bound.
        :param tls_ciphers: OpenSSL cipher string for ``set_ciphers``
            (``"cipher1:cipher2"``).
        :param socks_proxy: mapping with optional keys ``server``, ``port``,
            ``type``, ``user`` and ``password``; an empty ``server`` disables
            the proxy.  The mapping is not modified.
        :raises TLSError: no address of *host* could be connected to.
        :raises TLSUnknownHostError: certificate/host mismatch that
            *unknown_host_cb* did not accept.
        """
        self.timeout = 30 if timeout is None else timeout

        # Fill in missing proxy keys on a private copy: neither the caller's
        # dict nor the shared default argument is mutated.  The port default
        # is an int, matching the signature default.
        proxy = {"server": '', "port": 1080, "type": "SOCKS5", "user": '', "password": ''}
        proxy.update(socks_proxy or {})

        if client_sock is None:
            sock = self._open_socket(host, port, timeout, proxy)
        else:
            # eg. callhome
            sock = client_sock

        if tls_context is None:
            context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=trusted_certs)
        else:
            context = tls_context

        # Optional protocol-version bounds, e.g. "TLSv1_2" / "TLSv1_3".
        if tls_minimum_version and tls_minimum_version != "Auto-negotiation":
            context.minimum_version = getattr(ssl.TLSVersion, tls_minimum_version)
        if tls_maximum_version and tls_maximum_version != "Auto-negotiation":
            context.maximum_version = getattr(ssl.TLSVersion, tls_maximum_version)

        # set_ciphers() only affects TLS 1.2-and-below suites; TLS 1.3 suite
        # names (TLS_AES_256_GCM_SHA384, ...) in the string are ignored by
        # OpenSSL.  To force a TLS 1.2 cipher, also cap the maximum version
        # at TLSv1_2.  If the server supports none of the requested ciphers
        # the handshake fails.
        if tls_ciphers:
            context.set_ciphers(tls_ciphers)

        # The host name is checked manually after the handshake so that
        # unknown_host_cb can override a mismatch.
        context.check_hostname = False
        context.verify_mode = ssl.CERT_REQUIRED
        if client_cert and client_key:
            context.load_cert_chain(certfile=client_cert, keyfile=client_key)

        try:
            sslsock = context.wrap_socket(sock, server_side=False)
        except Exception:
            # Do not leak a socket opened here if wrapping fails.
            if client_sock is None:
                sock.close()
            raise

        print("TLS version in effect: ", sslsock.version())
        print("TLS cipher in effect: ", sslsock.cipher())

        self._host = host
        self._hostsrc = sslsock.getsockname()[0]

        cert = sslsock.getpeercert()

        # Only names containing a letter are checked (numeric addresses are
        # skipped, as before).
        if any(c.isalpha() for c in host):
            # match_hostname is a function of the ssl module, not a method of
            # SSLSocket; calling it on the socket raised AttributeError, which
            # was mistaken for a certificate mismatch.
            match_hostname = getattr(ssl, "match_hostname", None)
            if match_hostname is None:
                # NOTE(review): ssl.match_hostname is not available in this
                # interpreter, so the host name is NOT verified here.
                logger.warning("ssl.match_hostname unavailable; host name %s "
                               "was not checked against the certificate", host)
            else:
                try:
                    match_hostname(cert, host)
                except (ssl.CertificateError, ValueError):
                    if not unknown_host_cb(host):
                        sslsock.close()
                        raise TLSUnknownHostError(host)

        # Mark the session connected only once the peer has been accepted.
        self._connected = True
        self._channel = sslsock
        self._post_connect()

    def _open_socket(self, host, port, timeout, proxy):
        """Return a TCP socket connected to host:port, optionally via SOCKS.

        Every address returned by ``getaddrinfo`` is tried in order; the
        first successful connection wins.

        :param proxy: dict with keys ``server``, ``port``, ``type``, ``user``
            and ``password``; an empty ``server`` means a direct connection.
        :raises TLSError: none of the addresses could be connected to.
        """
        use_proxy = proxy["server"] != ''
        for af, socktype, proto, _canonname, sa in socket.getaddrinfo(
                host, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
            try:
                if use_proxy:
                    # Same API as socket.socket in the standard lib
                    sock = socks.socksocket(af, socktype, proto)
                else:
                    sock = socket.socket(af, socktype, proto)
            except socket.error:
                continue
            try:
                sock.settimeout(timeout)
                if use_proxy:
                    # Anything other than "SOCKS5" is treated as SOCKS4;
                    # remote DNS is enabled for SOCKS5 only.
                    if proxy["type"] == "SOCKS5":
                        type_id, rdns = socks.SOCKS5, True
                    else:
                        type_id, rdns = socks.SOCKS4, False
                    proxy_args = [type_id, proxy["server"], proxy["port"], rdns]
                    if proxy["user"] != '':
                        proxy_args.append(proxy["user"])
                        if proxy["password"] != '':
                            proxy_args.append(proxy["password"])
                    sock.set_proxy(*proxy_args)
                sock.connect(sa)
            except socket.error:
                sock.close()
                continue
            return sock
        raise TLSError("Could not open socket to %s:%s" % (host, port))
    

    def run(self):
        """I/O loop: receive and parse incoming data, send queued messages.

        Runs until an exception occurs.  Any exception (including the
        ``SessionCloseError`` raised when the peer closes the connection) is
        reported through ``self._dispatch_error`` and the session is closed.

        Incoming bytes are appended to ``self._buffer`` and parsed with
        ``_parse`` when ``self._delim_ver == 1``, otherwise with ``_parse2``.
        Outgoing messages are taken from ``self._q`` and framed the same way.
        All framing is done on bytes so that the chunk size counts octets and
        partial sends resume at the right offset for non-ASCII payloads.
        """
        chan = self._channel
        q = self._q

        # select() operates on the raw socket, so the SSL socket must be
        # non-blocking: a readable raw socket does not guarantee that
        # decrypted application data is available.
        chan.setblocking(0)

        try:
            while True:
                # Wake up every TICK seconds to check for something to send,
                # sooner when there is something to read.
                readable, _, _ = select([chan], [], [], TICK)
                if readable:
                    try:
                        data = chan.recv(BUF_SIZE)
                    except ssl.SSLWantReadError:
                        # SSL equivalent of EWOULDBLOCK: nothing decrypted yet.
                        continue

                    if not data:
                        # Peer closed the connection.
                        raise SessionCloseError(self._buffer.getvalue())

                    # Drain what the SSL layer has already buffered: select()
                    # will not report it because the raw socket is empty.
                    # recv() is never called with length 0, which could
                    # trigger a read on the raw socket.
                    pending = chan.pending()
                    while pending:
                        data += chan.recv(pending)
                        pending = chan.pending()

                    self._buffer.write(data)
                    if self._delim_ver == 1:
                        self._parse()
                    else:
                        self._parse2()

                if not q.empty():
                    logger.debug("Sending message")
                    # Queue items are bytes-like (they were decoded as UTF-8
                    # before); keep them as bytes from here on.
                    payload = bytes(q.get())
                    if self._delim_ver == 1:
                        frame = payload + MSG_DELIM.encode("utf-8")
                    else:
                        declared = self.chunk_msg_test_buffer_len
                        if declared == 0:
                            # The chunk size is a number of octets, not of
                            # characters.
                            declared = len(payload)
                        # A non-zero chunk_msg_test_buffer_len deliberately
                        # sends that (wrong) length instead.
                        frame = ("\n#%s\n" % declared).encode("utf-8") + payload + b"\n##\n"

                    while frame:
                        try:
                            n = chan.send(frame)
                        except ssl.SSLWantWriteError:
                            # Send buffer full: wait until writable, retry.
                            select([], [chan], [], TICK)
                            continue
                        except ssl.SSLWantReadError:
                            # The SSL layer needs incoming data first.
                            select([chan], [], [], TICK)
                            continue
                        if n <= 0:
                            raise SessionCloseError(self._buffer.getvalue(),
                                                    frame.decode("utf-8", "replace"))
                        frame = frame[n:]

        except Exception as e:
            logger.debug("Broke out of main loop, error=%r", e)
            self._dispatch_error(repr(e) + traceback.format_exc())
            self.close()
            
            
    @property
    def transport(self):
        return self._channel
    
