#!/usr/bin/python
# TorCtl.py -- Python module to interface with Tor Control interface.
# Copyright 2005 Nick Mathewson -- See LICENSE for licensing information.
#$Id: TorCtl0.py,v 1.7 2005/11/19 19:42:31 nickm Exp $

"""
TorCtl0 -- Library to control Tor processes.  See TorCtlDemo.py for example use.
"""

import binascii
import os
import sha
import socket
import struct
import sys
import TorCtl

__all__ = [
    "MSG_TYPE", "EVENT_TYPE", "CIRC_STATUS", "STREAM_STATUS",
    "OR_CONN_STATUS", "SIGNAL", "ERR_CODES",
    "TorCtlError", "ProtocolError", "ErrorReply", "Connection", "EventHandler",
    ]

class _Enum:
    # Helper: define an ordered dense name-to-number 1-1 mapping.
    def __init__(self, start, names):
        self.nameOf = {}
        idx = start
        for name in names:
            setattr(self,name,idx)
            self.nameOf[idx] = name
            idx += 1

class _Enum2:
    # Helper: define an ordered sparse name-to-number 1-1 mapping.
    def __init__(self, **args):
        self.__dict__.update(args)
        self.nameOf = {}
        for k,v in args.items():
            self.nameOf[v] = k

# Message types that client or server can send.
# _Enum numbers the names consecutively from 0x0000 in list order, so each
# name's position below determines its numeric value; reordering the list
# changes the values.
MSG_TYPE = _Enum(0x0000,
                 ["ERROR",
                  "DONE",
                  "SETCONF",
                  "GETCONF",
                  "CONFVALUE",
                  "SETEVENTS",
                  "EVENT",
                  "AUTH",
                  "SAVECONF",
                  "SIGNAL",
                  "MAPADDRESS",
                  "GETINFO",
                  "INFOVALUE",
                  "EXTENDCIRCUIT",
                  "ATTACHSTREAM",
                  "POSTDESCRIPTOR",
                  "FRAGMENTHEADER",
                  "FRAGMENT",
                  "REDIRECTSTREAM",
                  "CLOSESTREAM",
                  "CLOSECIRCUIT",
                  ])

# Make sure that the enumeration code is working: spot-check one value from
# the middle of the list and the last one.
assert MSG_TYPE.SAVECONF == 0x0008
assert MSG_TYPE.CLOSECIRCUIT == 0x0014

# Types of "EVENT" message.
# Numbered consecutively from 0x0001 in list order (see _Enum).
EVENT_TYPE = _Enum(0x0001,
                   ["CIRCSTATUS",
                    "STREAMSTATUS",
                    "ORCONNSTATUS",
                    "BANDWIDTH",
                    "OBSOLETE_LOG",
                    "NEWDESC",
                    "DEBUG_MSG",
                    "INFO_MSG",
                    "NOTICE_MSG",
                    "WARN_MSG",
                    "ERR_MSG",
                    ])

# Sanity-check the numbering: the last entry and one from the middle.
assert EVENT_TYPE.ERR_MSG == 0x000B
assert EVENT_TYPE.OBSOLETE_LOG == 0x0005

# Status codes for "CIRCSTATUS" events.
# Numbered consecutively from 0x00 in list order.
CIRC_STATUS = _Enum(0x00,
                    ["LAUNCHED",
                     "BUILT",
                     "EXTENDED",
                     "FAILED",
                     "CLOSED"])

# Status codes for "STREAMSTATUS" events.
# Numbered consecutively from 0x00 in list order.
STREAM_STATUS = _Enum(0x00,
                      ["SENT_CONNECT",
                       "SENT_RESOLVE",
                       "SUCCEEDED",
                       "FAILED",
                       "CLOSED",
                       "NEW_CONNECT",
                       "NEW_RESOLVE",
                       "DETACHED"])

# Status codes for "ORCONNSTATUS" events.
# Numbered consecutively from 0x00 in list order.
OR_CONN_STATUS = _Enum(0x00,
                       ["LAUNCHED","CONNECTED","FAILED","CLOSED"])

# Signal codes carried in the one-byte body of "SIGNAL" messages.  The
# values are sparse, so they are given explicitly via _Enum2.
SIGNAL = _Enum2(HUP=0x01,INT=0x02,USR1=0x0A,USR2=0x0C,TERM=0x0F)

# Error codes for "ERROR" messages: maps the 16-bit code found at the start
# of an ERROR message body to a human-readable description.
ERR_CODES = {
  0x0000 : "Unspecified error",
  0x0001 : "Internal error",
  0x0002 : "Unrecognized message type",
  0x0003 : "Syntax error",
  0x0004 : "Unrecognized configuration key",
  0x0005 : "Invalid configuration value",
  0x0006 : "Unrecognized byte code",
  0x0007 : "Unauthorized",
  0x0008 : "Failed authentication attempt",
  0x0009 : "Resource exhausted",
  0x000A : "No such stream",
  0x000B : "No such circuit",
  0x000C : "No such OR"
}

def _unpack_singleton_msg(msg):
    """Helper: unpack a single packet.  Return (None, minLength, body-so-far)
       on incomplete packet or (type,body,rest) on somplete packet
    """
    if len(msg) < 4:
        return None, 4, msg
    length,type = struct.unpack("!HH",msg)
    if len(msg) >= 4+length:
        return type,msg[4:4+length],msg[4+length:]
    else:
        return None,4+length,msg

def _minLengthToPack(bytes):
    """Return the minimum number of bytes needed to pack the message 'smg'"""
    whole,left = divmod(bytes,65535)
    if left:
        return whole*(65535+4)+4+left
    else:
        return whole*(65535+4)

def _unpack_msg(msg):
    """Unpack one complete message, reassembling fragments, from the front
       of 'msg'.

       Returns as for _unpack_singleton_msg: (None, minLength, msg) when
       'msg' does not yet contain the whole message, or (type, body, rest)
       when it does.

       A fragmented message starts with a FRAGMENTHEADER packet whose body
       begins with the real 16-bit type and 32-bit total length, followed by
       the first piece of the payload; the remainder arrives in FRAGMENT
       packets.

       Raises ProtocolError if the fragment header is too short, if a packet
       other than FRAGMENT interrupts the sequence, or if the fragments
       carry more data than the header declared.
    """
    tp,body,rest = _unpack_singleton_msg(msg)
    if tp != MSG_TYPE.FRAGMENTHEADER:
        # Either an ordinary packet or an incomplete one (tp is None).
        return tp, body, rest

    if len(body) < 6:
        raise ProtocolError("FRAGMENTHEADER message too short")

    realType,realLength = struct.unpack("!HL", body[:6])

    # Okay; could the message _possibly_ be here?
    minLength = _minLengthToPack(realLength+6)
    if len(msg) < minLength:
        return None,  minLength, msg

    # Okay; optimistically try to build up the msg.
    soFar = [ body[6:] ]
    lenSoFar = len(body)-6
    while lenSoFar < realLength:
        if len(rest) < 4:
            # Not even a fragment header is available.  Everything before
            # 'rest' is consumed; the remaining payload needs at least
            # _minLengthToPack() more bytes, counted from the start of
            # 'rest'.
            needed = len(msg)-len(rest)+_minLengthToPack(realLength-lenSoFar)
            return None, needed, msg
        ln, tp = struct.unpack("!HH", rest[:4])
        if tp != MSG_TYPE.FRAGMENT:
            raise ProtocolError("Missing FRAGMENT message")
        if lenSoFar + ln > realLength:
            raise ProtocolError(
                "Bad fragmentation: message longer than declared")
        if 4+ln > len(rest):
            # The final fragment we have is truncated: we need the rest of
            # it, plus packets for whatever payload follows it.
            leftInPacket = 4+ln-len(rest)
            inOtherPackets = realLength-lenSoFar-ln
            needed = len(msg)+leftInPacket+_minLengthToPack(inOtherPackets)
            return None, needed, msg
        soFar.append(rest[4:4+ln])
        lenSoFar += ln
        rest = rest[4+ln:]

    if lenSoFar > realLength:
        # The header packet alone carried more than it declared.
        raise ProtocolError("Bad fragmentation: message longer than declared")
    return realType, "".join(soFar), rest

def _receive_singleton_msg(s):
    """Read a single packet from the socket s.
    """
    body = ""
    header = s.recv(4)
    if not header:
        raise TorCtl.TorCtlClosed()
    length,type = struct.unpack("!HH",header)
    if length:
        while length > len(body):
            more = s.recv(length-len(body))
            if not more:
                raise TorCtl.TorCtlClosed()
            body += more
    return length,type,body

def _receive_message(s):
    """Read a single message (possibly multi-packet) from the socket s.

       Returns (length, type, body).  For a fragmented message these are the
       real length and type announced in the FRAGMENTHEADER packet and the
       reassembled body.

       Raises ProtocolError if the fragment header is too short, if a
       packet other than FRAGMENT arrives mid-message, or if more data
       arrives than the header declared.
    """
    length, tp, body = _receive_singleton_msg(s)
    if tp != MSG_TYPE.FRAGMENTHEADER:
        return length, tp, body
    if length < 6:
        raise ProtocolError("FRAGMENTHEADER message too short")
    realType,realLength = struct.unpack("!HL", body[:6])
    data = [ body[6:] ]
    soFar = len(data[0])
    # Only read more packets while payload is still outstanding; if the
    # header packet already carried everything, reading again would block
    # waiting for a packet that belongs to the next message.
    while soFar < realLength:
        length, tp, body = _receive_singleton_msg(s)
        if tp != MSG_TYPE.FRAGMENT:
            raise ProtocolError("Missing FRAGMENT message")
        soFar += length
        data.append(body)
    if soFar > realLength:
        raise ProtocolError("FRAGMENT message too long!")
    return realLength, realType, "".join(data)

def pack_message(type, body=""):
    """Given a message type and optional message body, generate the bytes
       of the packet(s) to send, as a single string.

       A body of up to 65535 bytes is sent as one packet: a 4-byte header
       (length, type) followed by the body.  A longer body is split into a
       FRAGMENTHEADER packet -- whose body is the real type, the real
       32-bit length, and the first 65535-6 bytes of payload -- followed by
       FRAGMENT packets of up to 65535 payload bytes each.
    """
    length = len(body)
    if length < 65536:
        return struct.pack("!HH", length, type) + body

    firstChunk = 65535-6
    msgs = [ struct.pack("!HHHL",
                         65535, MSG_TYPE.FRAGMENTHEADER, type, length),
             body[:firstChunk] ]
    for start in range(firstChunk, length, 65535):
        chunk = body[start:start+65535]
        # Every packet header is (length, type), in that order.
        msgs.append(struct.pack("!HH", len(chunk), MSG_TYPE.FRAGMENT))
        msgs.append(chunk)

    return "".join(msgs)

def _parseKV(body,sep=" ",term="\n"):
    """Helper: parse a key/value list of the form [key sep value term]* .
       Return a list of (k,v)."""
    res = []
    for line in body.split(term):
        if not line: continue
        k, v = line.split(sep,1)
        res.append((k,v))
    return res

def _unterminate(s):
    """Strip trailing NUL characters from s."""
    if s[-1] == '\0':
        return s[:-1]
    else:
        return s

class Connection(TorCtl._ConnectionBase):
    """A Connection represents a connection to the Tor process."""
    def __init__(self, sock):
        """Create a Connection to communicate with the Tor process over the
           socket 'sock'.
        """
        TorCtl._ConnectionBase.__init__(self)
        self._s = sock

    def set_event_handler(self, handler):
        """Cause future events from the Tor process to be sent to 'handler'.
           The handler's 'handle0' method is the one that gets installed.
        """
        self._handler = handler
        self._handleFn = handler.handle0

    def _doSend(self, msg):
        """Helper: Deliver a command to Tor.  'msg' is a (type, body) pair.
        """
        type, body = msg
        self._s.sendall(pack_message(type, body))

    def _read_reply(self):
        """Helper: read one message from Tor.  Returns (isEvent, (tp, body)),
           where isEvent is true iff the message is of type EVENT.
        """
        length, tp, body = _receive_message(self._s)
        return (tp == MSG_TYPE.EVENT), (tp, body)

    def _sendAndRecv(self, tp, msg="", expectedTypes=(MSG_TYPE.DONE,)):
        """Helper: Send a command of type 'tp' and body 'msg' to Tor,
           and wait for a command in response.  If the response type is
           in expectedTypes, return a (tp,body) tuple.  If it is an error,
           raise ErrorReply.  Otherwise, raise ProtocolError.
        """

        tp, msg = self.sendImpl(self._doSend, (tp, msg))
        if tp in expectedTypes:
            return tp, msg
        elif tp == MSG_TYPE.ERROR:
            # An ERROR body is a 16-bit error code followed by a message.
            if len(msg)<2:
                raise ProtocolError("(Truncated error message)")
            errCode, = struct.unpack("!H", msg[:2])
            raise ErrorReply((errCode,
                              ERR_CODES.get(errCode,"[unrecognized]"),
                              msg[2:]))
        else:
            raise ProtocolError("Unexpectd message type 0x%04x"%tp)

    def authenticate(self, secret=""):
        """Send an authenticating secret to Tor.  You'll need to call
           this method before other commands.  You need to use a
           password if Tor expects one.
        """
        self._sendAndRecv(MSG_TYPE.AUTH,secret)

    def get_option(self,name):
        """Get the value of the configuration option named 'name'.  To
           retrieve multiple values, pass a list for 'name' instead of
           a string.  Returns a list of (key,value) pairs.
        """
        if not isinstance(name, str):
            name = "".join(["%s\n"%s for s in name])
        tp,body = self._sendAndRecv(MSG_TYPE.GETCONF,name,[MSG_TYPE.CONFVALUE])
        return _parseKV(body)

    def set_option(self,key,value):
        """Set the value of the configuration option 'key' to the
           value 'value'.
        """
        # set_options expects a list of pairs, not a flat [key, value] list.
        self.set_options([(key, value)])

    def set_options(self,kvlist):
        """Given a list of (key,value) pairs, set them as configuration
           options.
        """
        msg = "".join(["%s %s\n"%(k,v) for k,v in kvlist])
        self._sendAndRecv(MSG_TYPE.SETCONF,msg)

    def reset_options(self, keylist):
        """Given a list of configuration keys, send a SETCONF message that
           names each key with no value.
        """
        msg = "".join(["%s\n"%k for k in keylist])
        self._sendAndRecv(MSG_TYPE.SETCONF,msg)

    def get_info(self,name):
        """Return the value of the internal information field named
           'name'.  To retrieve multiple values, pass a list for
           'name' instead of a string.  Returns a dictionary of
           key->value mappings.
        """
        if not isinstance(name, str):
            name = "".join(["%s\n"%s for s in name])
        tp, body = self._sendAndRecv(MSG_TYPE.GETINFO,name,[MSG_TYPE.INFOVALUE])
        # The reply is a NUL-separated sequence: key, value, key, value, ...
        kvs = body.split("\0")
        d = {}
        for i in range(0,len(kvs)-1,2):
            d[kvs[i]] = kvs[i+1]
        return d

    def set_events(self,events):
        """Change the list of events that the event handler is interested
           in to those in 'events', which is a list of EVENT_TYPE members
           or their corresponding strings.
        """
        evs = []
        for ev in events:
            if isinstance(ev, str):
                # Translate a (case-insensitive) name to its numeric code.
                evs.append(getattr(EVENT_TYPE, ev.upper()))
            else:
                evs.append(ev)
        # Pack the translated codes, not the caller's raw list, which may
        # still contain names.
        self._sendAndRecv(MSG_TYPE.SETEVENTS,
                     "".join([struct.pack("!H", event) for event in evs]))

    def save_conf(self):
        """Flush all configuration changes to disk.
        """
        self._sendAndRecv(MSG_TYPE.SAVECONF)

    def send_signal(self, sig):
        """Send the signal 'sig' to the Tor process; 'sig' must be a member of
           SIGNAL or a corresponding string.
        """
        try:
            sig = sig.upper()
        except AttributeError:
            # Not a string; assume it is already a numeric signal code.
            pass
        sig = { "HUP" : 0x01, "RELOAD" : 0x01,
                "INT" : 0x02, "SHUTDOWN" : 0x02,
                "DUMP" : 0x0A, "USR1" : 0x0A,
                "USR2" : 0x0C, "DEBUG" : 0x0C,
                "TERM" : 0x0F, "HALT" : 0x0F
                }.get(sig,sig)
        self._sendAndRecv(MSG_TYPE.SIGNAL,struct.pack("B",sig))

    def map_address(self, kvList):
        """Given a list of (old-address,new-address), have Tor redirect
           streams from old-address to new-address.  Old-address can be in a
           special "dont-care" form of "0.0.0.0" or ".".  Returns the reply
           body parsed into a list of (key,value) pairs.
        """
        msg = [ "%s %s\n"%(k,v) for k,v in kvList ]
        tp, body = self._sendAndRecv(MSG_TYPE.MAPADDRESS,"".join(msg))
        return _parseKV(body)

    def extend_circuit(self, circid, hops):
        """Tell Tor to extend the circuit identified by 'circid' through the
           servers named in the list "hops".  Returns the 32-bit number
           carried in the reply body.  Raises ProtocolError if the reply
           body is not exactly 4 bytes long.
        """
        msg = struct.pack("!L",int(circid)) + ",".join(hops) + "\0"
        tp, body = self._sendAndRecv(MSG_TYPE.EXTENDCIRCUIT,msg)
        if len(body) != 4:
            raise ProtocolError("Extendcircuit reply too short or long")
        return struct.unpack("!L",body)[0]

    def redirect_stream(self, streamid, newtarget):
        """Tell Tor to change the target address of the stream identified by
           'streamid' from its old value to 'newtarget'."""
        msg = struct.pack("!L",int(streamid)) + newtarget + "\0"
        self._sendAndRecv(MSG_TYPE.REDIRECTSTREAM,msg)

    def attach_stream(self, streamid, circid):
        """Tell Tor To attach stream 'streamid' to circuit 'circid'."""
        msg = struct.pack("!LL",int(streamid), int(circid))
        self._sendAndRecv(MSG_TYPE.ATTACHSTREAM,msg)

    def close_stream(self, streamid, reason=0, flags=()):
        """Close the stream 'streamid', giving the one-byte 'reason' code.

           The message carries a single flags byte.  This method interprets
           no named stream flags: an integer 'flags' is sent as given, and
           anything else (such as the default empty tuple) is sent as 0.
        """
        if isinstance(flags, int):
            flagByte = flags
        else:
            # A tuple cannot be packed with the "B" format.
            flagByte = 0
        msg = struct.pack("!LBB",int(streamid),reason,flagByte)
        self._sendAndRecv(MSG_TYPE.CLOSESTREAM,msg)

    def close_circuit(self, circid, flags=()):
        """Close the circuit 'circid'.  'flags' is a sequence of flag names;
           if it contains "IFUNUSED" the flags byte is sent as 1, otherwise
           as 0.
        """
        if "IFUNUSED" in flags:
            flags=1
        else:
            flags=0
        msg = struct.pack("!LB",int(circid),flags)
        self._sendAndRecv(MSG_TYPE.CLOSECIRCUIT,msg)

    def post_descriptor(self, descriptor):
        """Tell Tor about a new descriptor in 'descriptor'."""
        self._sendAndRecv(MSG_TYPE.POSTDESCRIPTOR,descriptor)
