home *** CD-ROM | disk | FTP | other *** search
/ PC World 2005 June / PCWorld_2005-06_cd.bin / software / vyzkuste / firewally / firewally.exe / framework-2.3.exe / rpc.py < prev    next >
Text File  |  2003-12-30  |  22KB  |  668 lines

  1. """RPC Implemention, originally written for the Python Idle IDE
  2.  
  3. For security reasons, GvR requested that Idle's Python execution server process
  4. connect to the Idle process, which listens for the connection.  Since Idle has
  5. has only one client per server, this was not a limitation.
  6.  
  7.    +---------------------------------+ +-------------+
  8.    | SocketServer.BaseRequestHandler | | SocketIO    |
  9.    +---------------------------------+ +-------------+
  10.                    ^                   | register()  |
  11.                    |                   | unregister()|
  12.                    |                   +-------------+
  13.                    |                      ^  ^
  14.                    |                      |  |
  15.                    | + -------------------+  |
  16.                    | |                       |
  17.    +-------------------------+        +-----------------+
  18.    | RPCHandler              |        | RPCClient       |
  19.    | [attribute of RPCServer]|        |                 |
  20.    +-------------------------+        +-----------------+
  21.  
  22. The RPCServer handler class is expected to provide register/unregister methods.
  23. RPCHandler inherits the mix-in class SocketIO, which provides these methods.
  24.  
  25. See the Idle run.main() docstring for further information on how this was
  26. accomplished in Idle.
  27.  
  28. """
  29.  
  30. import sys
  31. import os
  32. import socket
  33. import select
  34. import SocketServer
  35. import struct
  36. import cPickle as pickle
  37. import threading
  38. import Queue
  39. import traceback
  40. import copy_reg
  41. import types
  42. import marshal
  43.  
  44.  
  45. def unpickle_code(ms):
  46.     co = marshal.loads(ms)
  47.     assert isinstance(co, types.CodeType)
  48.     return co
  49.  
  50. def pickle_code(co):
  51.     assert isinstance(co, types.CodeType)
  52.     ms = marshal.dumps(co)
  53.     return unpickle_code, (ms,)
  54.  
  55. # XXX KBK 24Aug02 function pickling capability not used in Idle
  56. #  def unpickle_function(ms):
  57. #      return ms
  58.  
  59. #  def pickle_function(fn):
  60. #      assert isinstance(fn, type.FunctionType)
  61. #      return `fn`
  62.  
  63. copy_reg.pickle(types.CodeType, pickle_code, unpickle_code)
  64. # copy_reg.pickle(types.FunctionType, pickle_function, unpickle_function)
  65.  
  66. BUFSIZE = 8*1024
  67. LOCALHOST = '127.0.0.1'
  68.  
  69. class RPCServer(SocketServer.TCPServer):
  70.  
  71.     def __init__(self, addr, handlerclass=None):
  72.         if handlerclass is None:
  73.             handlerclass = RPCHandler
  74.         SocketServer.TCPServer.__init__(self, addr, handlerclass)
  75.  
  76.     def server_bind(self):
  77.         "Override TCPServer method, no bind() phase for connecting entity"
  78.         pass
  79.  
  80.     def server_activate(self):
  81.         """Override TCPServer method, connect() instead of listen()
  82.  
  83.         Due to the reversed connection, self.server_address is actually the
  84.         address of the Idle Client to which we are connecting.
  85.  
  86.         """
  87.         self.socket.connect(self.server_address)
  88.  
  89.     def get_request(self):
  90.         "Override TCPServer method, return already connected socket"
  91.         return self.socket, self.server_address
  92.  
  93.     def handle_error(self, request, client_address):
  94.         """Override TCPServer method
  95.  
  96.         Error message goes to __stderr__.  No error message if exiting
  97.         normally or socket raised EOF.  Other exceptions not handled in
  98.         server code will cause os._exit.
  99.  
  100.         """
  101.         try:
  102.             raise
  103.         except SystemExit:
  104.             raise
  105.         except:
  106.             erf = sys.__stderr__
  107.             print>>erf, '\n' + '-'*40
  108.             print>>erf, 'Unhandled server exception!'
  109.             print>>erf, 'Thread: %s' % threading.currentThread().getName()
  110.             print>>erf, 'Client Address: ', client_address
  111.             print>>erf, 'Request: ', repr(request)
  112.             traceback.print_exc(file=erf)
  113.             print>>erf, '\n*** Unrecoverable, server exiting!'
  114.             print>>erf, '-'*40
  115.             os._exit(0)
  116.  
  117. #----------------- end class RPCServer --------------------
  118.  
  119. objecttable = {}
  120. request_queue = Queue.Queue(0)
  121. response_queue = Queue.Queue(0)
  122.  
  123.  
  124. class SocketIO:
  125.  
  126.     nextseq = 0
  127.  
  128.     def __init__(self, sock, objtable=None, debugging=None):
  129.         self.sockthread = threading.currentThread()
  130.         if debugging is not None:
  131.             self.debugging = debugging
  132.         self.sock = sock
  133.         if objtable is None:
  134.             objtable = objecttable
  135.         self.objtable = objtable
  136.         self.responses = {}
  137.         self.cvars = {}
  138.  
  139.     def close(self):
  140.         sock = self.sock
  141.         self.sock = None
  142.         if sock is not None:
  143.             sock.close()
  144.  
  145.     def exithook(self):
  146.         "override for specific exit action"
  147.         os._exit()
  148.  
  149.     def debug(self, *args):
  150.         if not self.debugging:
  151.             return
  152.         s = self.location + " " + str(threading.currentThread().getName())
  153.         for a in args:
  154.             s = s + " " + str(a)
  155.         print>>sys.__stderr__, s
  156.  
  157.     def register(self, oid, object):
  158.         self.objtable[oid] = object
  159.  
  160.     def unregister(self, oid):
  161.         try:
  162.             del self.objtable[oid]
  163.         except KeyError:
  164.             pass
  165.  
  166.     def localcall(self, seq, request):
  167.         self.debug("localcall:", request)
  168.         try:
  169.             how, (oid, methodname, args, kwargs) = request
  170.         except TypeError:
  171.             return ("ERROR", "Bad request format")
  172.         if not self.objtable.has_key(oid):
  173.             return ("ERROR", "Unknown object id: %s" % `oid`)
  174.         obj = self.objtable[oid]
  175.         if methodname == "__methods__":
  176.             methods = {}
  177.             _getmethods(obj, methods)
  178.             return ("OK", methods)
  179.         if methodname == "__attributes__":
  180.             attributes = {}
  181.             _getattributes(obj, attributes)
  182.             return ("OK", attributes)
  183.         if not hasattr(obj, methodname):
  184.             return ("ERROR", "Unsupported method name: %s" % `methodname`)
  185.         method = getattr(obj, methodname)
  186.         try:
  187.             if how == 'CALL':
  188.                 ret = method(*args, **kwargs)
  189.                 if isinstance(ret, RemoteObject):
  190.                     ret = remoteref(ret)
  191.                 return ("OK", ret)
  192.             elif how == 'QUEUE':
  193.                 request_queue.put((seq, (method, args, kwargs)))
  194.                 return("QUEUED", None)
  195.             else:
  196.                 return ("ERROR", "Unsupported message type: %s" % how)
  197.         except SystemExit:
  198.             raise
  199.         except socket.error:
  200.             raise
  201.         except:
  202.             self.debug("localcall:EXCEPTION")
  203.             traceback.print_exc(file=sys.__stderr__)
  204.             return ("EXCEPTION", None)
  205.  
  206.     def remotecall(self, oid, methodname, args, kwargs):
  207.         self.debug("remotecall:asynccall: ", oid, methodname)
  208.         seq = self.asynccall(oid, methodname, args, kwargs)
  209.         return self.asyncreturn(seq)
  210.  
  211.     def remotequeue(self, oid, methodname, args, kwargs):
  212.         self.debug("remotequeue:asyncqueue: ", oid, methodname)
  213.         seq = self.asyncqueue(oid, methodname, args, kwargs)
  214.         return self.asyncreturn(seq)
  215.  
  216.     def asynccall(self, oid, methodname, args, kwargs):
  217.         request = ("CALL", (oid, methodname, args, kwargs))
  218.         seq = self.newseq()
  219.         if threading.currentThread() != self.sockthread:
  220.             cvar = threading.Condition()
  221.             self.cvars[seq] = cvar
  222.         self.debug(("asynccall:%d:" % seq), oid, methodname, args, kwargs)
  223.         self.putmessage((seq, request))
  224.         return seq
  225.  
  226.     def asyncqueue(self, oid, methodname, args, kwargs):
  227.         request = ("QUEUE", (oid, methodname, args, kwargs))
  228.         seq = self.newseq()
  229.         if threading.currentThread() != self.sockthread:
  230.             cvar = threading.Condition()
  231.             self.cvars[seq] = cvar
  232.         self.debug(("asyncqueue:%d:" % seq), oid, methodname, args, kwargs)
  233.         self.putmessage((seq, request))
  234.         return seq
  235.  
  236.     def asyncreturn(self, seq):
  237.         self.debug("asyncreturn:%d:call getresponse(): " % seq)
  238.         response = self.getresponse(seq, wait=0.05)
  239.         self.debug(("asyncreturn:%d:response: " % seq), response)
  240.         return self.decoderesponse(response)
  241.  
  242.     def decoderesponse(self, response):
  243.         how, what = response
  244.         if how == "OK":
  245.             return what
  246.         if how == "QUEUED":
  247.             return None
  248.         if how == "EXCEPTION":
  249.             self.debug("decoderesponse: EXCEPTION")
  250.             return None
  251.         if how == "EOF":
  252.             self.debug("decoderesponse: EOF")
  253.             self.decode_interrupthook()
  254.             return None
  255.         if how == "ERROR":
  256.             self.debug("decoderesponse: Internal ERROR:", what)
  257.             raise RuntimeError, what
  258.         raise SystemError, (how, what)
  259.  
  260.     def decode_interrupthook(self):
  261.         ""
  262.         raise EOFError
  263.  
  264.     def mainloop(self):
  265.         """Listen on socket until I/O not ready or EOF
  266.  
  267.         pollresponse() will loop looking for seq number None, which
  268.         never comes, and exit on EOFError.
  269.  
  270.         """
  271.         try:
  272.             self.getresponse(myseq=None, wait=0.05)
  273.         except EOFError:
  274.             self.debug("mainloop:return")
  275.             return
  276.  
  277.     def getresponse(self, myseq, wait):
  278.         response = self._getresponse(myseq, wait)
  279.         if response is not None:
  280.             how, what = response
  281.             if how == "OK":
  282.                 response = how, self._proxify(what)
  283.         return response
  284.  
  285.     def _proxify(self, obj):
  286.         if isinstance(obj, RemoteProxy):
  287.             return RPCProxy(self, obj.oid)
  288.         if isinstance(obj, types.ListType):
  289.             return map(self._proxify, obj)
  290.         # XXX Check for other types -- not currently needed
  291.         return obj
  292.  
  293.     def _getresponse(self, myseq, wait):
  294.         self.debug("_getresponse:myseq:", myseq)
  295.         if threading.currentThread() is self.sockthread:
  296.             # this thread does all reading of requests or responses
  297.             while 1:
  298.                 response = self.pollresponse(myseq, wait)
  299.                 if response is not None:
  300.                     return response
  301.         else:
  302.             # wait for notification from socket handling thread
  303.             cvar = self.cvars[myseq]
  304.             cvar.acquire()
  305.             while not self.responses.has_key(myseq):
  306.                 cvar.wait()
  307.             response = self.responses[myseq]
  308.             self.debug("_getresponse:%s: thread woke up: response: %s" %
  309.                        (myseq, response))
  310.             del self.responses[myseq]
  311.             del self.cvars[myseq]
  312.             cvar.release()
  313.             return response
  314.  
  315.     def newseq(self):
  316.         self.nextseq = seq = self.nextseq + 2
  317.         return seq
  318.  
  319.     def putmessage(self, message):
  320.         self.debug("putmessage:%d:" % message[0])
  321.         try:
  322.             s = pickle.dumps(message)
  323.         except pickle.UnpicklingError:
  324.             print >>sys.__stderr__, "Cannot pickle:", `message`
  325.             raise
  326.         s = struct.pack("<i", len(s)) + s
  327.         while len(s) > 0:
  328.             try:
  329.                 n = self.sock.send(s)
  330.             except (AttributeError, socket.error):
  331.                 # socket was closed
  332.                 raise IOError
  333.             else:
  334.                 s = s[n:]
  335.  
  336.     def ioready(self, wait):
  337.         r, w, x = select.select([self.sock.fileno()], [], [], wait)
  338.         return len(r)
  339.  
  340.     buffer = ""
  341.     bufneed = 4
  342.     bufstate = 0 # meaning: 0 => reading count; 1 => reading data
  343.  
  344.     def pollpacket(self, wait):
  345.         self._stage0()
  346.         if len(self.buffer) < self.bufneed:
  347.             if not self.ioready(wait):
  348.                 return None
  349.             try:
  350.                 s = self.sock.recv(BUFSIZE)
  351.             except socket.error:
  352.                 raise EOFError
  353.             if len(s) == 0:
  354.                 raise EOFError
  355.             self.buffer += s
  356.             self._stage0()
  357.         return self._stage1()
  358.  
  359.     def _stage0(self):
  360.         if self.bufstate == 0 and len(self.buffer) >= 4:
  361.             s = self.buffer[:4]
  362.             self.buffer = self.buffer[4:]
  363.             self.bufneed = struct.unpack("<i", s)[0]
  364.             self.bufstate = 1
  365.  
  366.     def _stage1(self):
  367.         if self.bufstate == 1 and len(self.buffer) >= self.bufneed:
  368.             packet = self.buffer[:self.bufneed]
  369.             self.buffer = self.buffer[self.bufneed:]
  370.             self.bufneed = 4
  371.             self.bufstate = 0
  372.             return packet
  373.  
  374.     def pollmessage(self, wait):
  375.         packet = self.pollpacket(wait)
  376.         if packet is None:
  377.             return None
  378.         try:
  379.             message = pickle.loads(packet)
  380.         except:
  381.             print >>sys.__stderr__, "-----------------------"
  382.             print >>sys.__stderr__, "cannot unpickle packet:", `packet`
  383.             traceback.print_stack(file=sys.__stderr__)
  384.             print >>sys.__stderr__, "-----------------------"
  385.             raise
  386.         return message
  387.  
  388.     def pollresponse(self, myseq, wait):
  389.         """Handle messages received on the socket.
  390.  
  391.         Some messages received may be asynchronous 'call' or 'queue' requests,
  392.         and some may be responses for other threads.
  393.  
  394.         'call' requests are passed to self.localcall() with the expectation of
  395.         immediate execution, during which time the socket is not serviced.
  396.  
  397.         'queue' requests are used for tasks (which may block or hang) to be
  398.         processed in a different thread.  These requests are fed into
  399.         request_queue by self.localcall().  Responses to queued requests are
  400.         taken from response_queue and sent across the link with the associated
  401.         sequence numbers.  Messages in the queues are (sequence_number,
  402.         request/response) tuples and code using this module removing messages
  403.         from the request_queue is responsible for returning the correct
  404.         sequence number in the response_queue.
  405.  
  406.         pollresponse() will loop until a response message with the myseq
  407.         sequence number is received, and will save other responses in
  408.         self.responses and notify the owning thread.
  409.  
  410.         """
  411.         while 1:
  412.             # send queued response if there is one available
  413.             try:
  414.                 qmsg = response_queue.get(0)
  415.             except Queue.Empty:
  416.                 pass
  417.             else:
  418.                 seq, response = qmsg
  419.                 message = (seq, ('OK', response))
  420.                 self.putmessage(message)
  421.             # poll for message on link
  422.             try:
  423.                 message = self.pollmessage(wait)
  424.                 if message is None:  # socket not ready
  425.                     return None
  426.             except EOFError:
  427.                 self.handle_EOF()
  428.                 return None
  429.             except AttributeError:
  430.                 return None
  431.             seq, resq = message
  432.             how = resq[0]
  433.             self.debug("pollresponse:%d:myseq:%s" % (seq, myseq))
  434.             # process or queue a request
  435.             if how in ("CALL", "QUEUE"):
  436.                 self.debug("pollresponse:%d:localcall:call:" % seq)
  437.                 response = self.localcall(seq, resq)
  438.                 self.debug("pollresponse:%d:localcall:response:%s"
  439.                            % (seq, response))
  440.                 if how == "CALL":
  441.                     self.putmessage((seq, response))
  442.                 elif how == "QUEUE":
  443.                     # don't acknowledge the 'queue' request!
  444.                     pass
  445.                 continue
  446.             # return if completed message transaction
  447.             elif seq == myseq:
  448.                 return resq
  449.             # must be a response for a different thread:
  450.             else:
  451.                 cv = self.cvars.get(seq, None)
  452.                 # response involving unknown sequence number is discarded,
  453.                 # probably intended for prior incarnation of server
  454.                 if cv is not None:
  455.                     cv.acquire()
  456.                     self.responses[seq] = resq
  457.                     cv.notify()
  458.                     cv.release()
  459.                 continue
  460.  
  461.     def handle_EOF(self):
  462.         "action taken upon link being closed by peer"
  463.         self.EOFhook()
  464.         self.debug("handle_EOF")
  465.         for key in self.cvars:
  466.             cv = self.cvars[key]
  467.             cv.acquire()
  468.             self.responses[key] = ('EOF', None)
  469.             cv.notify()
  470.             cv.release()
  471.         # call our (possibly overridden) exit function
  472.         self.exithook()
  473.  
  474.     def EOFhook(self):
  475.         "Classes using rpc client/server can override to augment EOF action"
  476.         pass
  477.  
  478. #----------------- end class SocketIO --------------------
  479.  
  480. class RemoteObject:
  481.     # Token mix-in class
  482.     pass
  483.  
  484. def remoteref(obj):
  485.     oid = id(obj)
  486.     objecttable[oid] = obj
  487.     return RemoteProxy(oid)
  488.  
  489. class RemoteProxy:
  490.  
  491.     def __init__(self, oid):
  492.         self.oid = oid
  493.  
  494. class RPCHandler(SocketServer.BaseRequestHandler, SocketIO):
  495.  
  496.     debugging = False
  497.     location = "#S"  # Server
  498.  
  499.     def __init__(self, sock, addr, svr):
  500.         svr.current_handler = self ## cgt xxx
  501.         SocketIO.__init__(self, sock)
  502.         SocketServer.BaseRequestHandler.__init__(self, sock, addr, svr)
  503.  
  504.     def handle(self):
  505.         "handle() method required by SocketServer"
  506.         self.mainloop()
  507.  
  508.     def get_remote_proxy(self, oid):
  509.         return RPCProxy(self, oid)
  510.  
  511. class RPCClient(SocketIO):
  512.  
  513.     debugging = False
  514.     location = "#C"  # Client
  515.  
  516.     nextseq = 1 # Requests coming from the client are odd numbered
  517.  
  518.     def __init__(self, address, family=socket.AF_INET, type=socket.SOCK_STREAM):
  519.         self.listening_sock = socket.socket(family, type)
  520.         self.listening_sock.setsockopt(socket.SOL_SOCKET,
  521.                                        socket.SO_REUSEADDR, 1)
  522.         self.listening_sock.bind(address)
  523.         self.listening_sock.listen(1)
  524.  
  525.     def accept(self):
  526.         working_sock, address = self.listening_sock.accept()
  527.         if self.debugging:
  528.             print>>sys.__stderr__, "****** Connection request from ", address
  529.         if address[0] == LOCALHOST:
  530.             SocketIO.__init__(self, working_sock)
  531.         else:
  532.             print>>sys.__stderr__, "** Invalid host: ", address
  533.             raise socket.error
  534.  
  535.     def get_remote_proxy(self, oid):
  536.         return RPCProxy(self, oid)
  537.  
  538. class RPCProxy:
  539.  
  540.     __methods = None
  541.     __attributes = None
  542.  
  543.     def __init__(self, sockio, oid):
  544.         self.sockio = sockio
  545.         self.oid = oid
  546.  
  547.     def __getattr__(self, name):
  548.         if self.__methods is None:
  549.             self.__getmethods()
  550.         if self.__methods.get(name):
  551.             return MethodProxy(self.sockio, self.oid, name)
  552.         if self.__attributes is None:
  553.             self.__getattributes()
  554.         if not self.__attributes.has_key(name):
  555.             raise AttributeError, name
  556.  
  557.     def __getattributes(self):
  558.         self.__attributes = self.sockio.remotecall(self.oid,
  559.                                                 "__attributes__", (), {})
  560.  
  561.     def __getmethods(self):
  562.         self.__methods = self.sockio.remotecall(self.oid,
  563.                                                 "__methods__", (), {})
  564.  
  565. def _getmethods(obj, methods):
  566.     # Helper to get a list of methods from an object
  567.     # Adds names to dictionary argument 'methods'
  568.     for name in dir(obj):
  569.         attr = getattr(obj, name)
  570.         if callable(attr):
  571.             methods[name] = 1
  572.     if type(obj) == types.InstanceType:
  573.         _getmethods(obj.__class__, methods)
  574.     if type(obj) == types.ClassType:
  575.         for super in obj.__bases__:
  576.             _getmethods(super, methods)
  577.  
  578. def _getattributes(obj, attributes):
  579.     for name in dir(obj):
  580.         attr = getattr(obj, name)
  581.         if not callable(attr):
  582.             attributes[name] = 1
  583.  
  584. class MethodProxy:
  585.  
  586.     def __init__(self, sockio, oid, name):
  587.         self.sockio = sockio
  588.         self.oid = oid
  589.         self.name = name
  590.  
  591.     def __call__(self, *args, **kwargs):
  592.         value = self.sockio.remotecall(self.oid, self.name, args, kwargs)
  593.         return value
  594.  
  595. #
  596. # Self Test
  597. #
  598.  
  599. def testServer(addr):
  600.     # XXX 25 Jul 02 KBK needs update to use rpc.py register/unregister methods
  601.     class RemotePerson:
  602.         def __init__(self,name):
  603.             self.name = name
  604.         def greet(self, name):
  605.             print "(someone called greet)"
  606.             print "Hello %s, I am %s." % (name, self.name)
  607.             print
  608.         def getName(self):
  609.             print "(someone called getName)"
  610.             print
  611.             return self.name
  612.         def greet_this_guy(self, name):
  613.             print "(someone called greet_this_guy)"
  614.             print "About to greet %s ..." % name
  615.             remote_guy = self.server.current_handler.get_remote_proxy(name)
  616.             remote_guy.greet("Thomas Edison")
  617.             print "Done."
  618.             print
  619.  
  620.     person = RemotePerson("Thomas Edison")
  621.     svr = RPCServer(addr)
  622.     svr.register('thomas', person)
  623.     person.server = svr # only required if callbacks are used
  624.  
  625.     # svr.serve_forever()
  626.     svr.handle_request()  # process once only
  627.  
  628. def testClient(addr):
  629.     "demonstrates RPC Client"
  630.     # XXX 25 Jul 02 KBK needs update to use rpc.py register/unregister methods
  631.     import time
  632.     clt=RPCClient(addr)
  633.     thomas = clt.get_remote_proxy("thomas")
  634.     print "The remote person's name is ..."
  635.     print thomas.getName()
  636.     # print clt.remotecall("thomas", "getName", (), {})
  637.     print
  638.     time.sleep(1)
  639.     print "Getting remote thomas to say hi..."
  640.     thomas.greet("Alexander Bell")
  641.     #clt.remotecall("thomas","greet",("Alexander Bell",), {})
  642.     print "Done."
  643.     print
  644.     time.sleep(2)
  645.     # demonstrates remote server calling local instance
  646.     class LocalPerson:
  647.         def __init__(self,name):
  648.             self.name = name
  649.         def greet(self, name):
  650.             print "You've greeted me!"
  651.         def getName(self):
  652.             return self.name
  653.     person = LocalPerson("Alexander Bell")
  654.     clt.register("alexander",person)
  655.     thomas.greet_this_guy("alexander")
  656.     # clt.remotecall("thomas","greet_this_guy",("alexander",), {})
  657.  
  658. def test():
  659.     addr=(LOCALHOST, 8833)
  660.     if len(sys.argv) == 2:
  661.         if sys.argv[1]=='-server':
  662.             testServer(addr)
  663.             return
  664.     testClient(addr)
  665.  
  666. if __name__ == '__main__':
  667.     test()
  668.