#---------------------------------------------------------------------
# Copyright (C) 2013  Seguesoft  Inc.
#                                                                             
# Redistribution of this software, in whole or in part, is prohibited         
# without the express written permission of Seguesoft. 
# Modified based on ncclient
# ----------------------------------------------------------------------
# Copyright 2009 Shikhar Bhushan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import socket
import socks
import getpass
from binascii import hexlify
from io import BytesIO 
from select import select
import traceback
import paramiko

from .errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
from .session import Session

import logging
#logger = logging.getLogger("ncclient.transport.ssh")
#logging.basicConfig(level=logging.DEBUG, format="%(filename)s:%(lineno)s %(levelname)s:%(message)s")

logger = logging.getLogger(__name__)
#logger.addHandler(logging.NullHandler())
        
# Maximum number of bytes requested per chan.recv() call in SSHSession.run().
BUF_SIZE = 4096
# End-of-message delimiter of the base:1.0 framing.
MSG_DELIM = "]]>]]>"
# Seconds SSHSession.run() blocks in select() per loop iteration.
TICK = 0.1
# End-of-chunks marker of the base:1.1 chunked framing.
MSG_DELIM_2 = "\n##\n"

# Default known_hosts location (expanded with os.path.expanduser()).
KH_FILE_NAME = "~/.ssh/known_hosts"

# Paramiko's cipher preference order as it was when this module was imported;
# SSHSession.connect() falls back to it when no ssh_ciphers are given.
ssh_original_ciphers = tuple(paramiko.Transport._preferred_ciphers)
# ASCII decimal digits; DIGIT1 omits '0' because a chunk-size cannot start with it.
DIGIT = list("0123456789")
DIGIT1 = DIGIT[1:]
def default_unknown_host_cb(host, fingerprint, key):
    """Host-key callback that rejects every host not already known.

    :meth:`SSHSession.connect` calls an unknown-host callback when the
    server's key is not found in the loaded host keys; a true return value
    accepts the key, a false one makes ``connect`` raise
    :exc:`SSHUnknownHostError`. This default always answers `False`, so
    unknown hosts are refused. Pass your own callback to verify keys
    programmatically (e.g. prompt the user, or add the key to the host keys).

    *host* is the hostname being connected to.

    *fingerprint* is the server key fingerprint as a colon-delimited hex
    string, e.g. `"4b:69:6c:72:6f:79:20:77:61:73:20:68:65:72:65:21"`.

    *key* is the server key object returned by paramiko.
    """
    accepted = False
    return accepted

def _colonify(fp):
    finga = fp[:2].decode("utf-8") 
    for idx  in range(2, len(fp), 2):
        finga += ":" + fp[idx:idx+2].decode("utf-8") 
    return finga

class SSHSession(Session):
    """Implements a NETCONF session over SSH (:rfc:`4742` / :rfc:`6242`).

    Incoming data is parsed with the ``]]>]]>`` end-of-message framing
    (:meth:`_parse`) when ``self._delim_ver == 1`` and with the chunked
    framing (:meth:`_parse2`) otherwise; :meth:`run` frames outgoing
    messages the same way.
    """

    # Largest chunk-size value allowed by RFC 6242.
    _MAX_CHUNK_SIZE = 4294967295

    def __init__(self, capabilities, timeout=60):
        Session.__init__(self, capabilities, timeout=timeout)
        self._host = "localhost"
        self._hostsrc = "localhost"

        self._host_keys = paramiko.HostKeys()
        self._transport = None
        self._connected = False
        self._channel = None

        # Received bytes not yet consumed as complete messages. The parsers
        # leave its position at the end so run() can append with write().
        self._buffer = BytesIO()
        # Scan bookkeeping for _parse() (base:1.0 framing).
        self._parsing_state = 0
        self._parsing_pos = 0

    # ------------------------------------------------------------------
    # Incoming message parsing
    # ------------------------------------------------------------------

    def _parse(self):
        """Extract ``]]>]]>``-delimited (base:1.0) messages from the buffer.

        Each complete message is stripped of surrounding whitespace, decoded
        as UTF-8, logged at INFO level and passed to
        ``self._dispatch_message``. Bytes after the last delimiter stay in
        the buffer. ``self._parsing_pos`` records how far the buffer has
        been scanned so already-searched data is not searched again.
        """
        delim = MSG_DELIM.encode("ascii")
        data = self._buffer.getvalue()
        # Back up len(delim) - 1 bytes: a delimiter may straddle the end of
        # the previous scan and the newly received data.
        search_from = max(0, self._parsing_pos - (len(delim) - 1))
        msg_start = 0
        while True:
            idx = data.find(delim, search_from)
            if idx < 0:
                break
            # Decode only whole messages: decoding single bytes fails on
            # multi-byte UTF-8 characters.
            newmsg = data[msg_start:idx].strip().decode()
            logger.info(newmsg)
            self._dispatch_message(newmsg)
            msg_start = search_from = idx + len(delim)

        self._keep_unconsumed(data, msg_start)
        self._parsing_state = 0
        self._parsing_pos = self._buffer.tell()

    def _parse2(self):
        """Extract :rfc:`6242` chunked (base:1.1) messages from the buffer.

        A message is one or more chunks ``\\n#<chunk-size>\\n<data>``
        followed by the end-of-chunks marker ``\\n##\\n``. Each complete
        message is decoded as UTF-8 and passed to ``self._dispatch_message``.
        An incomplete trailing message is left in the buffer and re-parsed
        when more data has arrived. Bytes that do not form a chunk header
        are skipped.

        Raises SessionCloseError (after closing the session) on a framing
        error: a malformed or too large chunk-size, or chunk data that is
        not followed by ``\\n#``.
        """
        data = self._buffer.getvalue()
        end = len(data)
        consumed = 0   # offset just past the last dispatched message
        pos = 0        # where to look for the next "\n#"
        chunks = []    # payloads of the message being assembled
        while True:
            hdr = data.find(b"\n#", pos)
            if hdr < 0 or hdr + 2 >= end:
                break  # no complete header yet; wait for more data
            mark = hdr + 2
            first = data[mark:mark + 1]
            if first == b"#":
                # "\n##" is end-of-chunks only when followed by "\n".
                if mark + 1 >= end:
                    break
                if data[mark + 1:mark + 2] == b"\n":
                    self._dispatch_message(b"".join(chunks).decode())
                    # Start the next message from scratch.
                    chunks = []
                    consumed = pos = mark + 2
                else:
                    pos = mark + 1
                continue
            if not b"1" <= first <= b"9":
                # Not a chunk header; rescan from here (the byte may be the
                # "\n" of a real header).
                pos = mark
                continue
            next_pos = self._read_chunk(data, mark, chunks)
            if next_pos is None:
                break  # chunk not complete yet
            pos = next_pos

        self._keep_unconsumed(data, consumed)

    def _read_chunk(self, data, pos, chunks):
        """Parse one chunk whose chunk-size digits start at ``data[pos]``.

        On success appends the chunk payload to *chunks* and returns the
        offset of the ``\\n#`` that follows it. Returns None when *data* does
        not yet hold the whole chunk plus that ``\\n#``. Closes the session
        and raises SessionCloseError on a framing error.
        """
        end = len(data)
        # chunk-size is at most 10 digits, so its "\n" is within 11 bytes.
        nl = data.find(b"\n", pos, pos + 11)
        field = data[pos:nl] if nl >= 0 else data[pos:pos + 11]
        if not field.isdigit():
            raise self._framing_error('chunk size invalid (not a number)',
                                      "invalid chunksize %r" % field)
        if nl < 0 and len(field) < 11:
            return None  # rest of the size line has not arrived yet
        size = int(field)
        if size > self._MAX_CHUNK_SIZE:
            raise self._framing_error('chunk size invalid > 4294967295',
                                      "invalid chunksize %s" % str(size))

        start = nl + 1
        stop = start + size
        if end < stop:
            return None  # payload incomplete
        payload = data[start:stop]
        if end < stop + 2:
            # A wrong byte right after the payload is detectable already.
            if end > stop and data[stop:stop + 1] != b"\n":
                raise self._chunk_length_error(chunks, payload, data)
            return None
        if data[stop:stop + 2] != b"\n#":
            raise self._chunk_length_error(chunks, payload, data)
        chunks.append(payload)
        return stop

    def _chunk_length_error(self, chunks, payload, data):
        """Close the session; return the error for chunk data not followed by ``\\n#``."""
        return self._framing_error(
            'The length of the chunk does not match the chunk size value! Data remains...',
            b"The length of the chunk does not match the chunk size value! rpcmsg read: " +
            b"".join(chunks) + payload + b" \nData remains... \nAll msg: " + data)

    def _framing_error(self, log_message, error):
        """Log *log_message*, close the session and return SessionCloseError(*error*) to raise."""
        logger.debug(log_message)
        self.close()
        return SessionCloseError(error)

    def _keep_unconsumed(self, data, consumed):
        """Drop the first *consumed* bytes of the receive buffer.

        *data* must be the buffer's current contents. The buffer is left
        positioned at its end so run() appends new data after it.
        """
        if consumed:
            self._buffer.close()
            self._buffer = BytesIO()
            self._buffer.write(data[consumed:])
        else:
            self._buffer.seek(0, os.SEEK_END)

    # ------------------------------------------------------------------
    # Host keys
    # ------------------------------------------------------------------

    def load_known_hosts(self, filename=None):
        """Load host keys from an openssh :file:`known_hosts`-style file. Can be called multiple times.

        If *filename* is not specified, :file:`~/.ssh/known_hosts` is tried
        and any error while loading it is logged and otherwise ignored.
        An explicitly given *filename* is remembered and errors loading it
        propagate.
        """
        if filename is None:
            filename = os.path.expanduser(KH_FILE_NAME)
            try:
                self._host_keys.load(filename)
            except Exception as e:
                # The default file is optional; don't fail the caller.
                logger.debug("Could not load known hosts from %s: %r", filename, e)
        else:
            self._host_keys_filename = filename
            self._host_keys.load(filename)

    def save_host_keys(self, filename=None):
        """Write all host keys of this session to a known_hosts-style file.

        *filename* defaults to :file:`~/.ssh/known_hosts` (expanded). Missing
        parent directories are created; an existing file is overwritten.
        """
        if filename is None:
            filename = os.path.expanduser(KH_FILE_NAME)
        directory = os.path.dirname(filename)
        if directory:  # a bare file name has no directory to create
            os.makedirs(directory, exist_ok=True)

        with open(filename, 'w') as f:
            f.write('# SSH host keys collected by paramiko\n')
            for hostname, keys in self._host_keys.items():
                for keytype, key in keys.items():
                    f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))

    def close(self):
        """Close the SSH transport if it is active and mark the session disconnected.

        Safe to call repeatedly and before :meth:`connect`.
        """
        if self._transport is not None and self._transport.is_active():
            self._transport.close()
        self._connected = False

    # ------------------------------------------------------------------
    # Connection setup
    # ------------------------------------------------------------------

    # REMEMBER to update transport.rst if sig. changes, since it is hardcoded there
    def connect(self, host, port=830, timeout=None, unknown_host_cb=default_unknown_host_cb,
                username=None, password=None, key_filename=None, allow_agent=False, look_for_keys=False,
                client_sock=None,
                socks_proxy={"server": '', "port": 1080, "type": "SOCKS5", "user": '', "password": ''},
                ssh_ciphers=None):
        """Connect via SSH and initialize the NETCONF session.

        Authentication is tried in this order: the keys in *key_filename*,
        SSH agent keys (if *allow_agent*), keys found in the usual locations
        (if *look_for_keys*), then *password*. To disable publickey
        authentication, leave *allow_agent* and *look_for_keys* `False` (the
        defaults) and give no *key_filename*.

        *host* is the hostname or IP address to connect to; it is also the
        name used for the host-key check

        *port* is by default 830, but some devices use the default SSH port of 22 so this may need to be specified

        *timeout* is an optional timeout for the socket connect; ``self.timeout``
        is set to it, or to 30 when it is None

        *unknown_host_cb* is called when the server host key is not recognized.
        It takes three arguments: the hostname, the colon-delimited hex
        fingerprint and the paramiko key (see :func:`default_unknown_host_cb`).
        If it returns a false value, :exc:`SSHUnknownHostError` is raised.

        *username* is the username for SSH authentication; defaults to the
        local user name

        *password* is the password used if using password authentication, or the passphrase to use for unlocking keys that require it

        *key_filename* is a private key file name, or a list of them

        *allow_agent* enables querying SSH agent (if found) for keys

        *look_for_keys* enables looking for :file:`id_rsa` / :file:`id_dsa`
        in :file:`~/.ssh/` and :file:`~/ssh/`

        *client_sock* is a client socket that is already connected (callhome);
        when given, no new socket is opened

        *socks_proxy*: ``{"server": '', "port": 1080, "type": "SOCKS5", "user": '', "password": ''}``.
        A proxy is used only when "server" is non-empty; missing keys take
        these defaults and "type" other than "SOCKS5" selects SOCKS4. The
        dict is not modified.

        *ssh_ciphers*: ['aes256-gcm@openssh.com', 'aes128-gcm@openssh.com', 'aes256-ctr', 'aes192-ctr', 'aes128-ctr'], or
        a colon separated string: 'aes256-gcm@openssh.com:aes128-gcm@openssh.com'.
        Applies to this connection only; when empty, paramiko's default
        order is used.

        Raises SSHError when no socket can be opened or SSH negotiation
        fails, SSHUnknownHostError, AuthenticationError, and ValueError for a
        malformed *ssh_ciphers*. If negotiation, the host-key check or
        authentication fails, the SSH transport is closed before raising.
        """
        self.timeout = 30 if timeout is None else timeout
        if username is None:
            username = getpass.getuser()

        # Validate before opening anything so a bad value leaks no socket.
        ciphers = self._normalize_ciphers(ssh_ciphers)

        if client_sock is None:
            sock = self._open_socket(host, port, timeout, socks_proxy)
        else:
            # eg. callhome
            sock = client_sock

        t = self._transport = paramiko.Transport(sock)
        # Per-transport override (same attribute paramiko's SecurityOptions
        # sets) instead of changing the class default for every Transport.
        t._preferred_ciphers = ciphers
        t.set_log_channel(logger.name)
        logger.debug("Starting client...")
        try:
            t.start_client()
        except paramiko.SSHException as e:
            t.close()
            raise SSHError('Negotiation failed') from e

        if key_filename is None:
            key_filenames = []
        elif isinstance(key_filename, str):
            key_filenames = [key_filename]
        else:
            key_filenames = key_filename

        try:
            self._verify_host_key(host, port, unknown_host_cb)
            self._host = host
            self._hostsrc = sock.getsockname()[0]
            self._auth(username, password, key_filenames, allow_agent, look_for_keys)
        except Exception:
            # Don't leave the transport thread and its socket running.
            t.close()
            raise

        self._connected = True  # there was no error authenticating
        c = self._channel = t.open_session()
        c.set_name("netconf")
        c.invoke_subsystem("netconf")

        self._post_connect()

    @staticmethod
    def _normalize_ciphers(ssh_ciphers):
        """Return *ssh_ciphers* as a tuple of names, or the default order if it is empty.

        Accepts a colon-separated string or a list/tuple of strings; raises
        ValueError otherwise.
        """
        if not ssh_ciphers:
            return ssh_original_ciphers
        if isinstance(ssh_ciphers, str):
            # "aes256-ctr:aes128-ctr" -> ["aes256-ctr", "aes128-ctr"]
            ciphers = [c.strip() for c in ssh_ciphers.split(":") if c.strip()]
        elif isinstance(ssh_ciphers, (list, tuple)):
            ciphers = list(ssh_ciphers)
        else:
            raise ValueError(
                "ssh_ciphers must be a colon-separated string or a list/tuple of strings"
            )
        if not all(isinstance(c, str) for c in ciphers):
            raise ValueError("Each cipher must be a string")
        return tuple(ciphers)

    def _open_socket(self, host, port, timeout, socks_proxy):
        """Return a TCP socket connected to host:port, via SOCKS if configured.

        Every address from getaddrinfo() is tried in turn; *timeout* is
        applied with settimeout(). Raises SSHError if none connects.
        """
        proxy = {"server": '', "port": 1080, "type": "SOCKS5", "user": '', "password": ''}
        proxy.update(socks_proxy or {})  # copy: never mutate the caller's dict
        use_proxy = proxy["server"] != ""

        for af, socktype, proto, _canonname, sa in socket.getaddrinfo(
                host, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
            try:
                if use_proxy:
                    # Same API as socket.socket in the standard lib
                    sock = socks.socksocket(af, socktype, proto)
                else:
                    sock = socket.socket(af, socktype, proto)
                sock.settimeout(timeout)
            except socket.error:
                continue
            try:
                if use_proxy:
                    self._set_socks_proxy(sock, proxy)
                sock.connect(sa)
            except socket.error:
                sock.close()
                continue
            return sock
        raise SSHError("Could not open socket to %s:%s" % (host, port))

    @staticmethod
    def _set_socks_proxy(sock, proxy):
        """Configure *sock* (a socks.socksocket) from the normalized *proxy* dict."""
        if proxy["type"] == "SOCKS5":
            type_id, rdns = socks.SOCKS5, True
        else:
            type_id, rdns = socks.SOCKS4, False
        port = proxy["port"]
        if isinstance(port, str):
            # A string port would make the proxy connect fail; "" means
            # "let PySocks pick its default".
            port = int(port) if port.strip() else None
        user = proxy["user"] or None
        # A password is only sent together with a user name.
        password = (proxy["password"] or None) if user else None
        sock.set_proxy(type_id, proxy["server"], port, rdns, user, password)

    def _verify_host_key(self, host, port, unknown_host_cb):
        """Check the server key against the loaded host keys, else ask *unknown_host_cb*.

        Raises SSHUnknownHostError if the key is unknown and the callback
        rejects it.
        """
        server_key = self._transport.get_remote_server_key()
        known_host = self._host_keys.check(host, server_key)
        if not known_host and str(port) != "22":
            # OpenSSH records keys for non-default ports as "[host]:port".
            known_host = self._host_keys.check("[%s]:%s" % (host, port), server_key)

        fingerprint = _colonify(hexlify(server_key.get_fingerprint()))
        if not known_host and not unknown_host_cb(host, fingerprint, server_key):
            raise SSHUnknownHostError(host, fingerprint)

    @staticmethod
    def _key_classes():
        """Return the paramiko private-key classes present in the installed paramiko.

        Looked up by name because some classes (DSSKey) may be missing from
        newer paramiko releases.
        """
        names = ("RSAKey", "DSSKey", "ECDSAKey", "Ed25519Key")
        return [getattr(paramiko, name) for name in names if hasattr(paramiko, name)]

    # on the lines of paramiko.SSHClient._auth()
    def _auth(self, username, password, key_filenames, allow_agent,
              look_for_keys):
        """Authenticate the transport, trying each method until one succeeds.

        Order: explicit key files, SSH agent keys, discovered key files,
        then password. Raises AuthenticationError (chained to the last
        failure, if any) when every attempt fails.
        """
        saved_exception = None
        key_classes = self._key_classes()

        for key_filename in key_filenames:
            for cls in key_classes:
                try:
                    key = cls.from_private_key_file(key_filename, password)
                    logger.debug("Trying key %s from %s",
                                 hexlify(key.get_fingerprint()), key_filename)
                    self._transport.auth_publickey(username, key)
                    return
                except Exception as e:
                    saved_exception = e
                    logger.debug("Reading private key file got exception: %r\nNot PEM format?", e)

        if allow_agent:
            for key in paramiko.Agent().get_keys():
                try:
                    logger.debug("Trying SSH agent key %s", hexlify(key.get_fingerprint()))
                    self._transport.auth_publickey(username, key)
                    return
                except Exception as e:
                    saved_exception = e
                    logger.debug(e)

        keyfiles = []
        if look_for_keys:
            dss_key_cls = getattr(paramiko, "DSSKey", None)
            # "~/ssh/" is checked as well for Windows users.
            for directory in ("~/.ssh", "~/ssh"):
                candidates = [(paramiko.RSAKey, "id_rsa")]
                if dss_key_cls is not None:
                    candidates.append((dss_key_cls, "id_dsa"))
                for cls, name in candidates:
                    path = os.path.expanduser("%s/%s" % (directory, name))
                    if os.path.isfile(path):
                        keyfiles.append((cls, path))

        for cls, filename in keyfiles:
            try:
                key = cls.from_private_key_file(filename, password)
                logger.debug("Trying discovered key %s in %s",
                             hexlify(key.get_fingerprint()), filename)
                self._transport.auth_publickey(username, key)
                return
            except Exception as e:
                saved_exception = e
                logger.debug("Tring look_for_keys failed: %r", e)

        if password is not None and password.strip() != "":
            try:
                self._transport.auth_password(username, password)
                return
            except Exception as e:
                saved_exception = e
                logger.debug(e)

        if saved_exception is not None:
            raise AuthenticationError(repr(saved_exception)) from saved_exception

        raise AuthenticationError("No authentication methods available")

    # ------------------------------------------------------------------
    # Main loop
    # ------------------------------------------------------------------

    def _frame_outgoing(self, msg):
        """Return *msg* (bytes or str) as UTF-8 bytes framed per ``self._delim_ver``."""
        payload = msg.encode("utf-8") if isinstance(msg, str) else msg
        if self._delim_ver == 1:
            return payload + MSG_DELIM.encode("ascii")
        if self.chunk_msg_test_buffer_len == 0:
            # RFC 6242 chunk-size counts octets, not characters.
            size = len(payload)
        else:
            # deliberately wrong length, used for framing tests
            size = self.chunk_msg_test_buffer_len
        header = ("\n#%s\n" % size).encode("ascii")
        return header + payload + MSG_DELIM_2.encode("ascii")

    def run(self):
        """Session main loop.

        Each iteration waits up to TICK seconds for the channel to become
        readable; received data (up to BUF_SIZE bytes) is appended to the
        buffer and parsed. Then at most one queued message is framed and
        sent. Any exception ends the loop: it is reported through
        ``self._dispatch_error`` and the session is closed.
        """
        chan = self._channel
        q = self._q

        try:
            while True:
                # select on a paramiko ssh channel object never returns it in
                # the writable list, so send_ready() is used for writing.
                readable, _, _ = select([chan], [], [], TICK)
                if readable:
                    data = chan.recv(BUF_SIZE)
                    if not data:
                        # Peer closed the channel.
                        raise SessionCloseError(
                            self._buffer.getvalue().decode("utf-8", "replace"))
                    self._buffer.write(data)
                    if self._delim_ver == 1:
                        self._parse()
                    else:
                        self._parse2()
                if not q.empty() and chan.send_ready():
                    logger.debug("Sending message")
                    # Send bytes: paramiko's window accounting uses len(data).
                    data = self._frame_outgoing(q.get())
                    while data:
                        n = chan.send(data)
                        if n <= 0:
                            raise SessionCloseError(self._buffer.getvalue(), data)
                        data = data[n:]
        except Exception as e:
            logger.debug("Broke out of main loop, error=%r", e)
            self._dispatch_error(repr(e) + traceback.format_exc())
            self.close()

    @property
    def transport(self):
        "Underlying `paramiko.Transport <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_ object. This makes it possible to call methods like :meth:`~paramiko.Transport.set_keepalive` on it."
        return self._transport
