home *** CD-ROM | disk | FTP | other *** search
/ PC World 2005 June / PCWorld_2005-06_cd.bin / software / vyzkuste / firewally / firewally.exe / framework-2.3.exe / pickletester.py < prev    next >
Text File  |  2003-12-30  |  30KB  |  959 lines

  1. import unittest
  2. import pickle
  3. import cPickle
  4. import pickletools
  5. import copy_reg
  6.  
  7. from test.test_support import TestFailed, have_unicode, TESTFN
  8.  
  9. # Tests that try a number of pickle protocols should have a
  10. #     for proto in protocols:
  11. # kind of outer loop.
  12. assert pickle.HIGHEST_PROTOCOL == cPickle.HIGHEST_PROTOCOL == 2
  13. protocols = range(pickle.HIGHEST_PROTOCOL + 1)
  14.  
  15.  
  16. # Return True if opcode code appears in the pickle, else False.
  17. def opcode_in_pickle(code, pickle):
  18.     for op, dummy, dummy in pickletools.genops(pickle):
  19.         if op.code == code:
  20.             return True
  21.     return False
  22.  
  23. # Return the number of times opcode code appears in pickle.
  24. def count_opcode(code, pickle):
  25.     n = 0
  26.     for op, dummy, dummy in pickletools.genops(pickle):
  27.         if op.code == code:
  28.             n += 1
  29.     return n
  30.  
  31. # We can't very well test the extension registry without putting known stuff
  32. # in it, but we have to be careful to restore its original state.  Code
  33. # should do this:
  34. #
  35. #     e = ExtensionSaver(extension_code)
  36. #     try:
  37. #         fiddle w/ the extension registry's stuff for extension_code
  38. #     finally:
  39. #         e.restore()
  40.  
  41. class ExtensionSaver:
  42.     # Remember current registration for code (if any), and remove it (if
  43.     # there is one).
  44.     def __init__(self, code):
  45.         self.code = code
  46.         if code in copy_reg._inverted_registry:
  47.             self.pair = copy_reg._inverted_registry[code]
  48.             copy_reg.remove_extension(self.pair[0], self.pair[1], code)
  49.         else:
  50.             self.pair = None
  51.  
  52.     # Restore previous registration for code.
  53.     def restore(self):
  54.         code = self.code
  55.         curpair = copy_reg._inverted_registry.get(code)
  56.         if curpair is not None:
  57.             copy_reg.remove_extension(curpair[0], curpair[1], code)
  58.         pair = self.pair
  59.         if pair is not None:
  60.             copy_reg.add_extension(pair[0], pair[1], code)
  61.  
  62. class C:
  63.     def __cmp__(self, other):
  64.         return cmp(self.__dict__, other.__dict__)
  65.  
  66. import __main__
  67. __main__.C = C
  68. C.__module__ = "__main__"
  69.  
  70. class myint(int):
  71.     def __init__(self, x):
  72.         self.str = str(x)
  73.  
  74. class initarg(C):
  75.  
  76.     def __init__(self, a, b):
  77.         self.a = a
  78.         self.b = b
  79.  
  80.     def __getinitargs__(self):
  81.         return self.a, self.b
  82.  
  83. class metaclass(type):
  84.     pass
  85.  
  86. class use_metaclass(object):
  87.     __metaclass__ = metaclass
  88.  
  89. # DATA0 .. DATA2 are the pickles we expect under the various protocols, for
  90. # the object returned by create_data().
  91.  
  92. # break into multiple strings to avoid confusing font-lock-mode
  93. DATA0 = """(lp1
  94. I0
  95. aL1L
  96. aF2
  97. ac__builtin__
  98. complex
  99. p2
  100. """ + \
  101. """(F3
  102. F0
  103. tRp3
  104. aI1
  105. aI-1
  106. aI255
  107. aI-255
  108. aI-256
  109. aI65535
  110. aI-65535
  111. aI-65536
  112. aI2147483647
  113. aI-2147483647
  114. aI-2147483648
  115. a""" + \
  116. """(S'abc'
  117. p4
  118. g4
  119. """ + \
  120. """(i__main__
  121. C
  122. p5
  123. """ + \
  124. """(dp6
  125. S'foo'
  126. p7
  127. I1
  128. sS'bar'
  129. p8
  130. I2
  131. sbg5
  132. tp9
  133. ag9
  134. aI5
  135. a.
  136. """
  137.  
  138. # Disassembly of DATA0.
  139. DATA0_DIS = """\
  140.     0: (    MARK
  141.     1: l        LIST       (MARK at 0)
  142.     2: p    PUT        1
  143.     5: I    INT        0
  144.     8: a    APPEND
  145.     9: L    LONG       1L
  146.    13: a    APPEND
  147.    14: F    FLOAT      2.0
  148.    17: a    APPEND
  149.    18: c    GLOBAL     '__builtin__ complex'
  150.    39: p    PUT        2
  151.    42: (    MARK
  152.    43: F        FLOAT      3.0
  153.    46: F        FLOAT      0.0
  154.    49: t        TUPLE      (MARK at 42)
  155.    50: R    REDUCE
  156.    51: p    PUT        3
  157.    54: a    APPEND
  158.    55: I    INT        1
  159.    58: a    APPEND
  160.    59: I    INT        -1
  161.    63: a    APPEND
  162.    64: I    INT        255
  163.    69: a    APPEND
  164.    70: I    INT        -255
  165.    76: a    APPEND
  166.    77: I    INT        -256
  167.    83: a    APPEND
  168.    84: I    INT        65535
  169.    91: a    APPEND
  170.    92: I    INT        -65535
  171.   100: a    APPEND
  172.   101: I    INT        -65536
  173.   109: a    APPEND
  174.   110: I    INT        2147483647
  175.   122: a    APPEND
  176.   123: I    INT        -2147483647
  177.   136: a    APPEND
  178.   137: I    INT        -2147483648
  179.   150: a    APPEND
  180.   151: (    MARK
  181.   152: S        STRING     'abc'
  182.   159: p        PUT        4
  183.   162: g        GET        4
  184.   165: (        MARK
  185.   166: i            INST       '__main__ C' (MARK at 165)
  186.   178: p        PUT        5
  187.   181: (        MARK
  188.   182: d            DICT       (MARK at 181)
  189.   183: p        PUT        6
  190.   186: S        STRING     'foo'
  191.   193: p        PUT        7
  192.   196: I        INT        1
  193.   199: s        SETITEM
  194.   200: S        STRING     'bar'
  195.   207: p        PUT        8
  196.   210: I        INT        2
  197.   213: s        SETITEM
  198.   214: b        BUILD
  199.   215: g        GET        5
  200.   218: t        TUPLE      (MARK at 151)
  201.   219: p    PUT        9
  202.   222: a    APPEND
  203.   223: g    GET        9
  204.   226: a    APPEND
  205.   227: I    INT        5
  206.   230: a    APPEND
  207.   231: .    STOP
  208. highest protocol among opcodes = 0
  209. """
  210.  
  211. DATA1 = (']q\x01(K\x00L1L\nG@\x00\x00\x00\x00\x00\x00\x00'
  212.          'c__builtin__\ncomplex\nq\x02(G@\x08\x00\x00\x00\x00\x00'
  213.          '\x00G\x00\x00\x00\x00\x00\x00\x00\x00tRq\x03K\x01J\xff\xff'
  214.          '\xff\xffK\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xff'
  215.          'J\x01\x00\xff\xffJ\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00'
  216.          '\x00\x80J\x00\x00\x00\x80(U\x03abcq\x04h\x04(c__main__\n'
  217.          'C\nq\x05oq\x06}q\x07(U\x03fooq\x08K\x01U\x03barq\tK\x02ubh'
  218.          '\x06tq\nh\nK\x05e.'
  219.         )
  220.  
  221. # Disassembly of DATA1.
  222. DATA1_DIS = """\
  223.     0: ]    EMPTY_LIST
  224.     1: q    BINPUT     1
  225.     3: (    MARK
  226.     4: K        BININT1    0
  227.     6: L        LONG       1L
  228.    10: G        BINFLOAT   2.0
  229.    19: c        GLOBAL     '__builtin__ complex'
  230.    40: q        BINPUT     2
  231.    42: (        MARK
  232.    43: G            BINFLOAT   3.0
  233.    52: G            BINFLOAT   0.0
  234.    61: t            TUPLE      (MARK at 42)
  235.    62: R        REDUCE
  236.    63: q        BINPUT     3
  237.    65: K        BININT1    1
  238.    67: J        BININT     -1
  239.    72: K        BININT1    255
  240.    74: J        BININT     -255
  241.    79: J        BININT     -256
  242.    84: M        BININT2    65535
  243.    87: J        BININT     -65535
  244.    92: J        BININT     -65536
  245.    97: J        BININT     2147483647
  246.   102: J        BININT     -2147483647
  247.   107: J        BININT     -2147483648
  248.   112: (        MARK
  249.   113: U            SHORT_BINSTRING 'abc'
  250.   118: q            BINPUT     4
  251.   120: h            BINGET     4
  252.   122: (            MARK
  253.   123: c                GLOBAL     '__main__ C'
  254.   135: q                BINPUT     5
  255.   137: o                OBJ        (MARK at 122)
  256.   138: q            BINPUT     6
  257.   140: }            EMPTY_DICT
  258.   141: q            BINPUT     7
  259.   143: (            MARK
  260.   144: U                SHORT_BINSTRING 'foo'
  261.   149: q                BINPUT     8
  262.   151: K                BININT1    1
  263.   153: U                SHORT_BINSTRING 'bar'
  264.   158: q                BINPUT     9
  265.   160: K                BININT1    2
  266.   162: u                SETITEMS   (MARK at 143)
  267.   163: b            BUILD
  268.   164: h            BINGET     6
  269.   166: t            TUPLE      (MARK at 112)
  270.   167: q        BINPUT     10
  271.   169: h        BINGET     10
  272.   171: K        BININT1    5
  273.   173: e        APPENDS    (MARK at 3)
  274.   174: .    STOP
  275. highest protocol among opcodes = 1
  276. """
  277.  
  278. DATA2 = ('\x80\x02]q\x01(K\x00\x8a\x01\x01G@\x00\x00\x00\x00\x00\x00\x00'
  279.          'c__builtin__\ncomplex\nq\x02G@\x08\x00\x00\x00\x00\x00\x00G\x00'
  280.          '\x00\x00\x00\x00\x00\x00\x00\x86Rq\x03K\x01J\xff\xff\xff\xffK'
  281.          '\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xffJ\x01\x00\xff\xff'
  282.          'J\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00\x00\x80J\x00\x00\x00'
  283.          '\x80(U\x03abcq\x04h\x04(c__main__\nC\nq\x05oq\x06}q\x07(U\x03foo'
  284.          'q\x08K\x01U\x03barq\tK\x02ubh\x06tq\nh\nK\x05e.')
  285.  
  286. # Disassembly of DATA2.
  287. DATA2_DIS = """\
  288.     0: \x80 PROTO      2
  289.     2: ]    EMPTY_LIST
  290.     3: q    BINPUT     1
  291.     5: (    MARK
  292.     6: K        BININT1    0
  293.     8: \x8a     LONG1      1L
  294.    11: G        BINFLOAT   2.0
  295.    20: c        GLOBAL     '__builtin__ complex'
  296.    41: q        BINPUT     2
  297.    43: G        BINFLOAT   3.0
  298.    52: G        BINFLOAT   0.0
  299.    61: \x86     TUPLE2
  300.    62: R        REDUCE
  301.    63: q        BINPUT     3
  302.    65: K        BININT1    1
  303.    67: J        BININT     -1
  304.    72: K        BININT1    255
  305.    74: J        BININT     -255
  306.    79: J        BININT     -256
  307.    84: M        BININT2    65535
  308.    87: J        BININT     -65535
  309.    92: J        BININT     -65536
  310.    97: J        BININT     2147483647
  311.   102: J        BININT     -2147483647
  312.   107: J        BININT     -2147483648
  313.   112: (        MARK
  314.   113: U            SHORT_BINSTRING 'abc'
  315.   118: q            BINPUT     4
  316.   120: h            BINGET     4
  317.   122: (            MARK
  318.   123: c                GLOBAL     '__main__ C'
  319.   135: q                BINPUT     5
  320.   137: o                OBJ        (MARK at 122)
  321.   138: q            BINPUT     6
  322.   140: }            EMPTY_DICT
  323.   141: q            BINPUT     7
  324.   143: (            MARK
  325.   144: U                SHORT_BINSTRING 'foo'
  326.   149: q                BINPUT     8
  327.   151: K                BININT1    1
  328.   153: U                SHORT_BINSTRING 'bar'
  329.   158: q                BINPUT     9
  330.   160: K                BININT1    2
  331.   162: u                SETITEMS   (MARK at 143)
  332.   163: b            BUILD
  333.   164: h            BINGET     6
  334.   166: t            TUPLE      (MARK at 112)
  335.   167: q        BINPUT     10
  336.   169: h        BINGET     10
  337.   171: K        BININT1    5
  338.   173: e        APPENDS    (MARK at 5)
  339.   174: .    STOP
  340. highest protocol among opcodes = 2
  341. """
  342.  
  343. def create_data():
  344.     c = C()
  345.     c.foo = 1
  346.     c.bar = 2
  347.     x = [0, 1L, 2.0, 3.0+0j]
  348.     # Append some integer test cases at cPickle.c's internal size
  349.     # cutoffs.
  350.     uint1max = 0xff
  351.     uint2max = 0xffff
  352.     int4max = 0x7fffffff
  353.     x.extend([1, -1,
  354.               uint1max, -uint1max, -uint1max-1,
  355.               uint2max, -uint2max, -uint2max-1,
  356.                int4max,  -int4max,  -int4max-1])
  357.     y = ('abc', 'abc', c, c)
  358.     x.append(y)
  359.     x.append(y)
  360.     x.append(5)
  361.     return x
  362.  
  363. class AbstractPickleTests(unittest.TestCase):
  364.     # Subclass must define self.dumps, self.loads, self.error.
  365.  
  366.     _testdata = create_data()
  367.  
  368.     def setUp(self):
  369.         pass
  370.  
  371.     def test_misc(self):
  372.         # test various datatypes not tested by testdata
  373.         for proto in protocols:
  374.             x = myint(4)
  375.             s = self.dumps(x, proto)
  376.             y = self.loads(s)
  377.             self.assertEqual(x, y)
  378.  
  379.             x = (1, ())
  380.             s = self.dumps(x, proto)
  381.             y = self.loads(s)
  382.             self.assertEqual(x, y)
  383.  
  384.             x = initarg(1, x)
  385.             s = self.dumps(x, proto)
  386.             y = self.loads(s)
  387.             self.assertEqual(x, y)
  388.  
  389.         # XXX test __reduce__ protocol?
  390.  
  391.     def test_roundtrip_equality(self):
  392.         expected = self._testdata
  393.         for proto in protocols:
  394.             s = self.dumps(expected, proto)
  395.             got = self.loads(s)
  396.             self.assertEqual(expected, got)
  397.  
  398.     def test_load_from_canned_string(self):
  399.         expected = self._testdata
  400.         for canned in DATA0, DATA1, DATA2:
  401.             got = self.loads(canned)
  402.             self.assertEqual(expected, got)
  403.  
  404.     # There are gratuitous differences between pickles produced by
  405.     # pickle and cPickle, largely because cPickle starts PUT indices at
  406.     # 1 and pickle starts them at 0.  See XXX comment in cPickle's put2() --
  407.     # there's a comment with an exclamation point there whose meaning
  408.     # is a mystery.  cPickle also suppresses PUT for objects with a refcount
  409.     # of 1.
  410.     def dont_test_disassembly(self):
  411.         from cStringIO import StringIO
  412.         from pickletools import dis
  413.  
  414.         for proto, expected in (0, DATA0_DIS), (1, DATA1_DIS):
  415.             s = self.dumps(self._testdata, proto)
  416.             filelike = StringIO()
  417.             dis(s, out=filelike)
  418.             got = filelike.getvalue()
  419.             self.assertEqual(expected, got)
  420.  
  421.     def test_recursive_list(self):
  422.         l = []
  423.         l.append(l)
  424.         for proto in protocols:
  425.             s = self.dumps(l, proto)
  426.             x = self.loads(s)
  427.             self.assertEqual(x, l)
  428.             self.assertEqual(x, x[0])
  429.             self.assertEqual(id(x), id(x[0]))
  430.  
  431.     def test_recursive_dict(self):
  432.         d = {}
  433.         d[1] = d
  434.         for proto in protocols:
  435.             s = self.dumps(d, proto)
  436.             x = self.loads(s)
  437.             self.assertEqual(x, d)
  438.             self.assertEqual(x[1], x)
  439.             self.assertEqual(id(x[1]), id(x))
  440.  
  441.     def test_recursive_inst(self):
  442.         i = C()
  443.         i.attr = i
  444.         for proto in protocols:
  445.             s = self.dumps(i, 2)
  446.             x = self.loads(s)
  447.             self.assertEqual(x, i)
  448.             self.assertEqual(x.attr, x)
  449.             self.assertEqual(id(x.attr), id(x))
  450.  
  451.     def test_recursive_multi(self):
  452.         l = []
  453.         d = {1:l}
  454.         i = C()
  455.         i.attr = d
  456.         l.append(i)
  457.         for proto in protocols:
  458.             s = self.dumps(l, proto)
  459.             x = self.loads(s)
  460.             self.assertEqual(x, l)
  461.             self.assertEqual(x[0], i)
  462.             self.assertEqual(x[0].attr, d)
  463.             self.assertEqual(x[0].attr[1], x)
  464.             self.assertEqual(x[0].attr[1][0], i)
  465.             self.assertEqual(x[0].attr[1][0].attr, d)
  466.  
  467.     def test_garyp(self):
  468.         self.assertRaises(self.error, self.loads, 'garyp')
  469.  
  470.     def test_insecure_strings(self):
  471.         insecure = ["abc", "2 + 2", # not quoted
  472.                     #"'abc' + 'def'", # not a single quoted string
  473.                     "'abc", # quote is not closed
  474.                     "'abc\"", # open quote and close quote don't match
  475.                     "'abc'   ?", # junk after close quote
  476.                     "'\\'", # trailing backslash
  477.                     # some tests of the quoting rules
  478.                     #"'abc\"\''",
  479.                     #"'\\\\a\'\'\'\\\'\\\\\''",
  480.                     ]
  481.         for s in insecure:
  482.             buf = "S" + s + "\012p0\012."
  483.             self.assertRaises(ValueError, self.loads, buf)
  484.  
  485.     if have_unicode:
  486.         def test_unicode(self):
  487.             endcases = [unicode(''), unicode('<\\u>'), unicode('<\\\u1234>'),
  488.                         unicode('<\n>'),  unicode('<\\>')]
  489.             for proto in protocols:
  490.                 for u in endcases:
  491.                     p = self.dumps(u, proto)
  492.                     u2 = self.loads(p)
  493.                     self.assertEqual(u2, u)
  494.  
  495.     def test_ints(self):
  496.         import sys
  497.         for proto in protocols:
  498.             n = sys.maxint
  499.             while n:
  500.                 for expected in (-n, n):
  501.                     s = self.dumps(expected, proto)
  502.                     n2 = self.loads(s)
  503.                     self.assertEqual(expected, n2)
  504.                 n = n >> 1
  505.  
  506.     def test_maxint64(self):
  507.         maxint64 = (1L << 63) - 1
  508.         data = 'I' + str(maxint64) + '\n.'
  509.         got = self.loads(data)
  510.         self.assertEqual(got, maxint64)
  511.  
  512.         # Try too with a bogus literal.
  513.         data = 'I' + str(maxint64) + 'JUNK\n.'
  514.         self.assertRaises(ValueError, self.loads, data)
  515.  
  516.     def test_long(self):
  517.         for proto in protocols:
  518.             # 256 bytes is where LONG4 begins.
  519.             for nbits in 1, 8, 8*254, 8*255, 8*256, 8*257:
  520.                 nbase = 1L << nbits
  521.                 for npos in nbase-1, nbase, nbase+1:
  522.                     for n in npos, -npos:
  523.                         pickle = self.dumps(n, proto)
  524.                         got = self.loads(pickle)
  525.                         self.assertEqual(n, got)
  526.         # Try a monster.  This is quadratic-time in protos 0 & 1, so don't
  527.         # bother with those.
  528.         nbase = long("deadbeeffeedface", 16)
  529.         nbase += nbase << 1000000
  530.         for n in nbase, -nbase:
  531.             p = self.dumps(n, 2)
  532.             got = self.loads(p)
  533.             self.assertEqual(n, got)
  534.  
  535.     def test_reduce(self):
  536.         pass
  537.  
  538.     def test_getinitargs(self):
  539.         pass
  540.  
  541.     def test_metaclass(self):
  542.         a = use_metaclass()
  543.         for proto in protocols:
  544.             s = self.dumps(a, proto)
  545.             b = self.loads(s)
  546.             self.assertEqual(a.__class__, b.__class__)
  547.  
  548.     def test_structseq(self):
  549.         import time
  550.         import os
  551.  
  552.         t = time.localtime()
  553.         for proto in protocols:
  554.             s = self.dumps(t, proto)
  555.             u = self.loads(s)
  556.             self.assertEqual(t, u)
  557.             if hasattr(os, "stat"):
  558.                 t = os.stat(os.curdir)
  559.                 s = self.dumps(t, proto)
  560.                 u = self.loads(s)
  561.                 self.assertEqual(t, u)
  562.             if hasattr(os, "statvfs"):
  563.                 t = os.statvfs(os.curdir)
  564.                 s = self.dumps(t, proto)
  565.                 u = self.loads(s)
  566.                 self.assertEqual(t, u)
  567.  
  568.     # Tests for protocol 2
  569.  
  570.     def test_proto(self):
  571.         build_none = pickle.NONE + pickle.STOP
  572.         for proto in protocols:
  573.             expected = build_none
  574.             if proto >= 2:
  575.                 expected = pickle.PROTO + chr(proto) + expected
  576.             p = self.dumps(None, proto)
  577.             self.assertEqual(p, expected)
  578.  
  579.         oob = protocols[-1] + 1     # a future protocol
  580.         badpickle = pickle.PROTO + chr(oob) + build_none
  581.         try:
  582.             self.loads(badpickle)
  583.         except ValueError, detail:
  584.             self.failUnless(str(detail).startswith(
  585.                                             "unsupported pickle protocol"))
  586.         else:
  587.             self.fail("expected bad protocol number to raise ValueError")
  588.  
  589.     def test_long1(self):
  590.         x = 12345678910111213141516178920L
  591.         for proto in protocols:
  592.             s = self.dumps(x, proto)
  593.             y = self.loads(s)
  594.             self.assertEqual(x, y)
  595.             self.assertEqual(opcode_in_pickle(pickle.LONG1, s), proto >= 2)
  596.  
  597.     def test_long4(self):
  598.         x = 12345678910111213141516178920L << (256*8)
  599.         for proto in protocols:
  600.             s = self.dumps(x, proto)
  601.             y = self.loads(s)
  602.             self.assertEqual(x, y)
  603.             self.assertEqual(opcode_in_pickle(pickle.LONG4, s), proto >= 2)
  604.  
  605.     def test_short_tuples(self):
  606.         # Map (proto, len(tuple)) to expected opcode.
  607.         expected_opcode = {(0, 0): pickle.TUPLE,
  608.                            (0, 1): pickle.TUPLE,
  609.                            (0, 2): pickle.TUPLE,
  610.                            (0, 3): pickle.TUPLE,
  611.                            (0, 4): pickle.TUPLE,
  612.  
  613.                            (1, 0): pickle.EMPTY_TUPLE,
  614.                            (1, 1): pickle.TUPLE,
  615.                            (1, 2): pickle.TUPLE,
  616.                            (1, 3): pickle.TUPLE,
  617.                            (1, 4): pickle.TUPLE,
  618.  
  619.                            (2, 0): pickle.EMPTY_TUPLE,
  620.                            (2, 1): pickle.TUPLE1,
  621.                            (2, 2): pickle.TUPLE2,
  622.                            (2, 3): pickle.TUPLE3,
  623.                            (2, 4): pickle.TUPLE,
  624.                           }
  625.         a = ()
  626.         b = (1,)
  627.         c = (1, 2)
  628.         d = (1, 2, 3)
  629.         e = (1, 2, 3, 4)
  630.         for proto in protocols:
  631.             for x in a, b, c, d, e:
  632.                 s = self.dumps(x, proto)
  633.                 y = self.loads(s)
  634.                 self.assertEqual(x, y, (proto, x, s, y))
  635.                 expected = expected_opcode[proto, len(x)]
  636.                 self.assertEqual(opcode_in_pickle(expected, s), True)
  637.  
  638.     def test_singletons(self):
  639.         # Map (proto, singleton) to expected opcode.
  640.         expected_opcode = {(0, None): pickle.NONE,
  641.                            (1, None): pickle.NONE,
  642.                            (2, None): pickle.NONE,
  643.  
  644.                            (0, True): pickle.INT,
  645.                            (1, True): pickle.INT,
  646.                            (2, True): pickle.NEWTRUE,
  647.  
  648.                            (0, False): pickle.INT,
  649.                            (1, False): pickle.INT,
  650.                            (2, False): pickle.NEWFALSE,
  651.                           }
  652.         for proto in protocols:
  653.             for x in None, False, True:
  654.                 s = self.dumps(x, proto)
  655.                 y = self.loads(s)
  656.                 self.assert_(x is y, (proto, x, s, y))
  657.                 expected = expected_opcode[proto, x]
  658.                 self.assertEqual(opcode_in_pickle(expected, s), True)
  659.  
  660.     def test_newobj_tuple(self):
  661.         x = MyTuple([1, 2, 3])
  662.         x.foo = 42
  663.         x.bar = "hello"
  664.         for proto in protocols:
  665.             s = self.dumps(x, proto)
  666.             y = self.loads(s)
  667.             self.assertEqual(tuple(x), tuple(y))
  668.             self.assertEqual(x.__dict__, y.__dict__)
  669.  
  670.     def test_newobj_list(self):
  671.         x = MyList([1, 2, 3])
  672.         x.foo = 42
  673.         x.bar = "hello"
  674.         for proto in protocols:
  675.             s = self.dumps(x, proto)
  676.             y = self.loads(s)
  677.             self.assertEqual(list(x), list(y))
  678.             self.assertEqual(x.__dict__, y.__dict__)
  679.  
  680.     def test_newobj_generic(self):
  681.         for proto in protocols:
  682.             for C in myclasses:
  683.                 B = C.__base__
  684.                 x = C(C.sample)
  685.                 x.foo = 42
  686.                 s = self.dumps(x, proto)
  687.                 y = self.loads(s)
  688.                 detail = (proto, C, B, x, y, type(y))
  689.                 self.assertEqual(B(x), B(y), detail)
  690.                 self.assertEqual(x.__dict__, y.__dict__, detail)
  691.  
  692.     # Register a type with copy_reg, with extension code extcode.  Pickle
  693.     # an object of that type.  Check that the resulting pickle uses opcode
  694.     # (EXT[124]) under proto 2, and not in proto 1.
  695.  
  696.     def produce_global_ext(self, extcode, opcode):
  697.         e = ExtensionSaver(extcode)
  698.         try:
  699.             copy_reg.add_extension(__name__, "MyList", extcode)
  700.             x = MyList([1, 2, 3])
  701.             x.foo = 42
  702.             x.bar = "hello"
  703.  
  704.             # Dump using protocol 1 for comparison.
  705.             s1 = self.dumps(x, 1)
  706.             self.assert_(__name__ in s1)
  707.             self.assert_("MyList" in s1)
  708.             self.assertEqual(opcode_in_pickle(opcode, s1), False)
  709.  
  710.             y = self.loads(s1)
  711.             self.assertEqual(list(x), list(y))
  712.             self.assertEqual(x.__dict__, y.__dict__)
  713.  
  714.             # Dump using protocol 2 for test.
  715.             s2 = self.dumps(x, 2)
  716.             self.assert_(__name__ not in s2)
  717.             self.assert_("MyList" not in s2)
  718.             self.assertEqual(opcode_in_pickle(opcode, s2), True)
  719.  
  720.             y = self.loads(s2)
  721.             self.assertEqual(list(x), list(y))
  722.             self.assertEqual(x.__dict__, y.__dict__)
  723.  
  724.         finally:
  725.             e.restore()
  726.  
  727.     def test_global_ext1(self):
  728.         self.produce_global_ext(0x00000001, pickle.EXT1)  # smallest EXT1 code
  729.         self.produce_global_ext(0x000000ff, pickle.EXT1)  # largest EXT1 code
  730.  
  731.     def test_global_ext2(self):
  732.         self.produce_global_ext(0x00000100, pickle.EXT2)  # smallest EXT2 code
  733.         self.produce_global_ext(0x0000ffff, pickle.EXT2)  # largest EXT2 code
  734.         self.produce_global_ext(0x0000abcd, pickle.EXT2)  # check endianness
  735.  
  736.     def test_global_ext4(self):
  737.         self.produce_global_ext(0x00010000, pickle.EXT4)  # smallest EXT4 code
  738.         self.produce_global_ext(0x7fffffff, pickle.EXT4)  # largest EXT4 code
  739.         self.produce_global_ext(0x12abcdef, pickle.EXT4)  # check endianness
  740.  
  741.     def test_list_chunking(self):
  742.         n = 10  # too small to chunk
  743.         x = range(n)
  744.         for proto in protocols:
  745.             s = self.dumps(x, proto)
  746.             y = self.loads(s)
  747.             self.assertEqual(x, y)
  748.             num_appends = count_opcode(pickle.APPENDS, s)
  749.             self.assertEqual(num_appends, proto > 0)
  750.  
  751.         n = 2500  # expect at least two chunks when proto > 0
  752.         x = range(n)
  753.         for proto in protocols:
  754.             s = self.dumps(x, proto)
  755.             y = self.loads(s)
  756.             self.assertEqual(x, y)
  757.             num_appends = count_opcode(pickle.APPENDS, s)
  758.             if proto == 0:
  759.                 self.assertEqual(num_appends, 0)
  760.             else:
  761.                 self.failUnless(num_appends >= 2)
  762.  
  763.     def test_dict_chunking(self):
  764.         n = 10  # too small to chunk
  765.         x = dict.fromkeys(range(n))
  766.         for proto in protocols:
  767.             s = self.dumps(x, proto)
  768.             y = self.loads(s)
  769.             self.assertEqual(x, y)
  770.             num_setitems = count_opcode(pickle.SETITEMS, s)
  771.             self.assertEqual(num_setitems, proto > 0)
  772.  
  773.         n = 2500  # expect at least two chunks when proto > 0
  774.         x = dict.fromkeys(range(n))
  775.         for proto in protocols:
  776.             s = self.dumps(x, proto)
  777.             y = self.loads(s)
  778.             self.assertEqual(x, y)
  779.             num_setitems = count_opcode(pickle.SETITEMS, s)
  780.             if proto == 0:
  781.                 self.assertEqual(num_setitems, 0)
  782.             else:
  783.                 self.failUnless(num_setitems >= 2)
  784.  
  785.     def test_simple_newobj(self):
  786.         x = object.__new__(SimpleNewObj)  # avoid __init__
  787.         x.abc = 666
  788.         for proto in protocols:
  789.             s = self.dumps(x, proto)
  790.             self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), proto >= 2)
  791.             y = self.loads(s)   # will raise TypeError if __init__ called
  792.             self.assertEqual(y.abc, 666)
  793.             self.assertEqual(x.__dict__, y.__dict__)
  794.  
  795.     def test_newobj_list_slots(self):
  796.         x = SlotList([1, 2, 3])
  797.         x.foo = 42
  798.         x.bar = "hello"
  799.         s = self.dumps(x, 2)
  800.         y = self.loads(s)
  801.         self.assertEqual(list(x), list(y))
  802.         self.assertEqual(x.__dict__, y.__dict__)
  803.         self.assertEqual(x.foo, y.foo)
  804.         self.assertEqual(x.bar, y.bar)
  805.  
  806.     def test_reduce_overrides_default_reduce_ex(self):
  807.         for proto in 0, 1, 2:
  808.             x = REX_one()
  809.             self.assertEqual(x._reduce_called, 0)
  810.             s = self.dumps(x, proto)
  811.             self.assertEqual(x._reduce_called, 1)
  812.             y = self.loads(s)
  813.             self.assertEqual(y._reduce_called, 0)
  814.  
  815.     def test_reduce_ex_called(self):
  816.         for proto in 0, 1, 2:
  817.             x = REX_two()
  818.             self.assertEqual(x._proto, None)
  819.             s = self.dumps(x, proto)
  820.             self.assertEqual(x._proto, proto)
  821.             y = self.loads(s)
  822.             self.assertEqual(y._proto, None)
  823.  
  824.     def test_reduce_ex_overrides_reduce(self):
  825.         for proto in 0, 1, 2:
  826.             x = REX_three()
  827.             self.assertEqual(x._proto, None)
  828.             s = self.dumps(x, proto)
  829.             self.assertEqual(x._proto, proto)
  830.             y = self.loads(s)
  831.             self.assertEqual(y._proto, None)
  832.  
  833. # Test classes for reduce_ex
  834.  
  835. class REX_one(object):
  836.     _reduce_called = 0
  837.     def __reduce__(self):
  838.         self._reduce_called = 1
  839.         return REX_one, ()
  840.     # No __reduce_ex__ here, but inheriting it from object
  841.  
  842. class REX_two(object):
  843.     _proto = None
  844.     def __reduce_ex__(self, proto):
  845.         self._proto = proto
  846.         return REX_two, ()
  847.     # No __reduce__ here, but inheriting it from object
  848.  
  849. class REX_three(object):
  850.     _proto = None
  851.     def __reduce_ex__(self, proto):
  852.         self._proto = proto
  853.         return REX_two, ()
  854.     def __reduce__(self):
  855.         raise TestFailed, "This __reduce__ shouldn't be called"
  856.  
  857. # Test classes for newobj
  858.  
  859. class MyInt(int):
  860.     sample = 1
  861.  
  862. class MyLong(long):
  863.     sample = 1L
  864.  
  865. class MyFloat(float):
  866.     sample = 1.0
  867.  
  868. class MyComplex(complex):
  869.     sample = 1.0 + 0.0j
  870.  
  871. class MyStr(str):
  872.     sample = "hello"
  873.  
  874. class MyUnicode(unicode):
  875.     sample = u"hello \u1234"
  876.  
  877. class MyTuple(tuple):
  878.     sample = (1, 2, 3)
  879.  
  880. class MyList(list):
  881.     sample = [1, 2, 3]
  882.  
  883. class MyDict(dict):
  884.     sample = {"a": 1, "b": 2}
  885.  
  886. myclasses = [MyInt, MyLong, MyFloat,
  887.              MyComplex,
  888.              MyStr, MyUnicode,
  889.              MyTuple, MyList, MyDict]
  890.  
  891.  
  892. class SlotList(MyList):
  893.     __slots__ = ["foo"]
  894.  
  895. class SimpleNewObj(object):
  896.     def __init__(self, a, b, c):
  897.         # raise an error, to make sure this isn't called
  898.         raise TypeError("SimpleNewObj.__init__() didn't expect to get called")
  899.  
  900. class AbstractPickleModuleTests(unittest.TestCase):
  901.  
  902.     def test_dump_closed_file(self):
  903.         import os
  904.         f = open(TESTFN, "w")
  905.         try:
  906.             f.close()
  907.             self.assertRaises(ValueError, self.module.dump, 123, f)
  908.         finally:
  909.             os.remove(TESTFN)
  910.  
  911.     def test_load_closed_file(self):
  912.         import os
  913.         f = open(TESTFN, "w")
  914.         try:
  915.             f.close()
  916.             self.assertRaises(ValueError, self.module.dump, 123, f)
  917.         finally:
  918.             os.remove(TESTFN)
  919.  
  920.     def test_highest_protocol(self):
  921.         # Of course this needs to be changed when HIGHEST_PROTOCOL changes.
  922.         self.assertEqual(self.module.HIGHEST_PROTOCOL, 2)
  923.  
  924.  
  925. class AbstractPersistentPicklerTests(unittest.TestCase):
  926.  
  927.     # This class defines persistent_id() and persistent_load()
  928.     # functions that should be used by the pickler.  All even integers
  929.     # are pickled using persistent ids.
  930.  
  931.     def persistent_id(self, object):
  932.         if isinstance(object, int) and object % 2 == 0:
  933.             self.id_count += 1
  934.             return str(object)
  935.         else:
  936.             return None
  937.  
  938.     def persistent_load(self, oid):
  939.         self.load_count += 1
  940.         object = int(oid)
  941.         assert object % 2 == 0
  942.         return object
  943.  
  944.     def test_persistence(self):
  945.         self.id_count = 0
  946.         self.load_count = 0
  947.         L = range(10)
  948.         self.assertEqual(self.loads(self.dumps(L)), L)
  949.         self.assertEqual(self.id_count, 5)
  950.         self.assertEqual(self.load_count, 5)
  951.  
  952.     def test_bin_persistence(self):
  953.         self.id_count = 0
  954.         self.load_count = 0
  955.         L = range(10)
  956.         self.assertEqual(self.loads(self.dumps(L, 1)), L)
  957.         self.assertEqual(self.id_count, 5)
  958.         self.assertEqual(self.load_count, 5)
  959.