add tool/sexp.py
[gnuk/gnuk.git] / tool / gnuk_upgrade.py
1 #! /usr/bin/python
2
3 """
4 gnuk_upgrade.py - a tool to upgrade firmware of Gnuk Token
5
6 Copyright (C) 2012 Free Software Initiative of Japan
7 Author: NIIBE Yutaka <gniibe@fsij.org>
8
9 This file is a part of Gnuk, a GnuPG USB Token implementation.
10
11 Gnuk is free software: you can redistribute it and/or modify it
12 under the terms of the GNU General Public License as published by
13 the Free Software Foundation, either version 3 of the License, or
14 (at your option) any later version.
15
16 Gnuk is distributed in the hope that it will be useful, but WITHOUT
17 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
18 or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
19 License for more details.
20
21 You should have received a copy of the GNU General Public License
22 along with this program.  If not, see <http://www.gnu.org/licenses/>.
23 """
24
25 from struct import *
26 import sys, time, os, binascii, string
27
28 # INPUT: binary files (regnual_image, upgrade_firmware_image)
29
30 # Assume only single CCID device is attached to computer, and it's Gnuk Token
31
32 import usb
33
34 # USB class, subclass, protocol
35 CCID_CLASS = 0x0B
36 CCID_SUBCLASS = 0x00
37 CCID_PROTOCOL_0 = 0x00
38
39 def icc_compose(msg_type, data_len, slot, seq, param, data):
40     return pack('<BiBBBH', msg_type, data_len, slot, seq, 0, param) + data
41
42 def iso7816_compose(ins, p1, p2, data, cls=0x00):
43     data_len = len(data)
44     if data_len == 0:
45         return pack('>BBBB', cls, ins, p1, p2)
46     else:
47         return pack('>BBBBB', cls, ins, p1, p2, data_len) + data
48
49 class regnual(object):
50     def __init__(self, dev):
51         conf = dev.configurations[0]
52         intf_alt = conf.interfaces[0]
53         intf = intf_alt[0]
54         if intf.interfaceClass != 0xff:
55             raise ValueError, "Wrong interface class"
56         self.__devhandle = dev.open()
57         try:
58             self.__devhandle.setConfiguration(conf)
59         except:
60             pass
61         self.__devhandle.claimInterface(intf)
62         self.__devhandle.setAltInterface(intf)
63
64     def mem_info(self):
65         mem = self.__devhandle.controlMsg(requestType = 0xc0, request = 0,
66                                           value = 0, index = 0, buffer = 8,
67                                           timeout = 10000)
68         start = ((mem[3]*256 + mem[2])*256 + mem[1])*256 + mem[0]
69         end = ((mem[7]*256 + mem[6])*256 + mem[5])*256 + mem[4]
70         return (start, end)
71
72     def download(self, start, data):
73         addr = start
74         addr_end = (start + len(data)) & 0xffffff00
75         i = (addr - 0x08000000) / 0x100
76         j = 0
77         print "start %08x" % addr
78         print "end   %08x" % addr_end
79         while addr < addr_end:
80             print "# %08x: %d: %d : %d" % (addr, i, j, 256)
81             self.__devhandle.controlMsg(requestType = 0x40, request = 1,
82                                         value = 0, index = 0,
83                                         buffer = data[j*256:j*256+256],
84                                         timeout = 10000)
85             crc32code = crc32(data[j*256:j*256+256])
86             res = self.__devhandle.controlMsg(requestType = 0xc0, request = 2,
87                                               value = 0, index = 0, buffer = 4,
88                                               timeout = 10000)
89             r_value = ((res[3]*256 + res[2])*256 + res[1])*256 + res[0]
90             if (crc32code ^ r_value) != 0xffffffff:
91                 print "failure"
92             self.__devhandle.controlMsg(requestType = 0x40, request = 3,
93                                         value = i, index = 0,
94                                         buffer = None,
95                                         timeout = 10000)
96             time.sleep(0.010)
97             res = self.__devhandle.controlMsg(requestType = 0xc0, request = 2,
98                                               value = 0, index = 0, buffer = 4,
99                                               timeout = 10000)
100             r_value = ((res[3]*256 + res[2])*256 + res[1])*256 + res[0]
101             if r_value == 0:
102                 print "failure"
103             i = i+1
104             j = j+1
105             addr = addr + 256
106         residue = len(data) % 256
107         if residue != 0:
108             print "# %08x: %d : %d" % (addr, i, residue)
109             self.__devhandle.controlMsg(requestType = 0x40, request = 1,
110                                         value = 0, index = 0,
111                                         buffer = data[j*256:],
112                                         timeout = 10000)
113             crc32code = crc32(data[j*256:].ljust(256,chr(255)))
114             res = self.__devhandle.controlMsg(requestType = 0xc0, request = 2,
115                                               value = 0, index = 0, buffer = 4,
116                                               timeout = 10000)
117             r_value = ((res[3]*256 + res[2])*256 + res[1])*256 + res[0]
118             if (crc32code ^ r_value) != 0xffffffff:
119                 print "failure"
120             self.__devhandle.controlMsg(requestType = 0x40, request = 3,
121                                         value = i, index = 0,
122                                         buffer = None,
123                                         timeout = 10000)
124             time.sleep(0.010)
125             res = self.__devhandle.controlMsg(requestType = 0xc0, request = 2,
126                                               value = 0, index = 0, buffer = 4,
127                                               timeout = 10000)
128             r_value = ((res[3]*256 + res[2])*256 + res[1])*256 + res[0]
129             if r_value == 0:
130                 print "failure"
131
132     def protect(self):
133         self.__devhandle.controlMsg(requestType = 0x40, request = 4,
134                                     value = 0, index = 0, buffer = None,
135                                     timeout = 10000)
136         time.sleep(0.100)
137         res = self.__devhandle.controlMsg(requestType = 0xc0, request = 2,
138                                           value = 0, index = 0, buffer = 4,
139                                           timeout = 10000)
140         r_value = ((res[3]*256 + res[2])*256 + res[1])*256 + res[0]
141         if r_value == 0:
142             print "protection failure"
143
144     def finish(self):
145         self.__devhandle.controlMsg(requestType = 0x40, request = 5,
146                                     value = 0, index = 0, buffer = None,
147                                     timeout = 10000)
148
149     def reset_device(self):
150         try:
151             self.__devhandle.reset()
152         except:
153             pass
154
155 # This class only supports Gnuk (for now) 
156 class gnuk_token(object):
157     def __init__(self, device, configuration, interface):
158         """
159         __init__(device, configuration, interface) -> None
160         Initialize the device.
161         device: usb.Device object.
162         configuration: configuration number.
163         interface: usb.Interface object representing the interface and altenate setting.
164         """
165         if interface.interfaceClass != CCID_CLASS:
166             raise ValueError, "Wrong interface class"
167         if interface.interfaceSubClass != CCID_SUBCLASS:
168             raise ValueError, "Wrong interface sub class"
169         self.__devhandle = device.open()
170         try:
171             self.__devhandle.setConfiguration(configuration)
172         except:
173             pass
174         self.__devhandle.claimInterface(interface)
175         self.__devhandle.setAltInterface(interface)
176
177         self.__intf = interface.interfaceNumber
178         self.__alt = interface.alternateSetting
179         self.__conf = configuration
180
181         self.__bulkout = 1
182         self.__bulkin  = 0x81
183
184         self.__timeout = 10000
185         self.__seq = 0
186
187     def reset_device(self):
188         try:
189             self.__devhandle.reset()
190         except:
191             pass
192
193     def stop_gnuk(self):
194         self.__devhandle.releaseInterface()
195         self.__devhandle.setConfiguration(0)
196         return
197
198     def mem_info(self):
199         mem = self.__devhandle.controlMsg(requestType = 0xc0, request = 0,
200                                           value = 0, index = 0, buffer = 8,
201                                           timeout = 10)
202         start = ((mem[3]*256 + mem[2])*256 + mem[1])*256 + mem[0]
203         end = ((mem[7]*256 + mem[6])*256 + mem[5])*256 + mem[4]
204         return (start, end)
205
206     def download(self, start, data):
207         addr = start
208         addr_end = (start + len(data)) & 0xffffff00
209         i = (addr - 0x20000000) / 0x100
210         j = 0
211         print "start %08x" % addr
212         print "end   %08x" % addr_end
213         while addr < addr_end:
214             print "# %08x: %d : %d" % (addr, i, 256)
215             self.__devhandle.controlMsg(requestType = 0x40, request = 1,
216                                         value = i, index = 0,
217                                         buffer = data[j*256:j*256+256],
218                                         timeout = 10)
219             i = i+1
220             j = j+1
221             addr = addr + 256
222         residue = len(data) % 256
223         if residue != 0:
224             print "# %08x: %d : %d" % (addr, i, residue)
225             self.__devhandle.controlMsg(requestType = 0x40, request = 1,
226                                         value = i, index = 0,
227                                         buffer = data[j*256:],
228                                         timeout = 10)
229
230     def execute(self, last_addr):
231         i = (last_addr - 0x20000000) / 0x100
232         o = (last_addr - 0x20000000) % 0x100
233         self.__devhandle.controlMsg(requestType = 0x40, request = 2,
234                                     value = i, index = o, buffer = None,
235                                     timeout = 10)
236
237     def icc_get_result(self):
238         msg = self.__devhandle.bulkRead(self.__bulkin, 1024, self.__timeout)
239         if len(msg) < 10:
240             raise ValueError, "icc_get_result"
241         msg_type = msg[0]
242         data_len = msg[1] + (msg[2]<<8) + (msg[3]<<16) + (msg[4]<<24)
243         slot = msg[5]
244         seq = msg[6]
245         status = msg[7]
246         error = msg[8]
247         chain = msg[9]
248         data = msg[10:]
249         # XXX: check msg_type, data_len, slot, seq, error
250         return (status, chain, data)
251
252     def icc_get_status(self):
253         msg = icc_compose(0x65, 0, 0, self.__seq, 0, "")
254         self.__devhandle.bulkWrite(self.__bulkout, msg, self.__timeout)
255         self.__seq += 1
256         status, chain, data = self.icc_get_result()
257         # XXX: check chain, data
258         return status
259
260     def icc_power_on(self):
261         msg = icc_compose(0x62, 0, 0, self.__seq, 0, "")
262         self.__devhandle.bulkWrite(self.__bulkout, msg, self.__timeout)
263         self.__seq += 1
264         status, chain, data = self.icc_get_result()
265         # XXX: check status, chain
266         return data             # ATR
267
268     def icc_power_off(self):
269         msg = icc_compose(0x63, 0, 0, self.__seq, 0, "")
270         self.__devhandle.bulkWrite(self.__bulkout, msg, self.__timeout)
271         self.__seq += 1
272         status, chain, data = self.icc_get_result()
273         # XXX: check chain, data
274         return status
275
276     def icc_send_data_block(self, data):
277         msg = icc_compose(0x6f, len(data), 0, self.__seq, 0, data)
278         self.__devhandle.bulkWrite(self.__bulkout, msg, self.__timeout)
279         self.__seq += 1
280         return self.icc_get_result()
281
282     def icc_send_cmd(self, data):
283         status, chain, data_rcv = self.icc_send_data_block(data)
284         if chain == 0:
285             return data_rcv
286         elif chain == 1:
287             d = data_rcv
288             while True:
289                 msg = icc_compose(0x6f, 0, 0, self.__seq, 0x10, "")
290                 self.__devhandle.bulkWrite(self.__bulkout, msg, self.__timeout)
291                 self.__seq += 1
292                 status, chain, data_rcv = self.icc_get_result()
293                 # XXX: check status
294                 d += data_rcv
295                 if chain == 2:
296                     break
297                 elif chain == 3:
298                     continue
299                 else:
300                     raise ValueError, "icc_send_cmd chain"
301             return d
302         else:
303             raise ValueError, "icc_send_cmd"
304
305     def cmd_get_response(self, expected_len):
306         cmd_data = iso7816_compose(0xc0, 0x00, 0x00, '') + pack('>B', expected_len)
307         response = self.icc_send_cmd(cmd_data)
308         return response[:-2]
309
310     def cmd_verify(self, who, passwd):
311         cmd_data = iso7816_compose(0x20, 0x00, 0x80+who, passwd)
312         sw = self.icc_send_cmd(cmd_data)
313         if len(sw) != 2:
314             raise ValueError, sw
315         if not (sw[0] == 0x90 and sw[1] == 0x00):
316             raise ValueError, sw
317
318     def cmd_select_openpgp(self):
319         cmd_data = iso7816_compose(0xa4, 0x04, 0x0c, "\xD2\x76\x00\x01\x24\x01")
320         sw = self.icc_send_cmd(cmd_data)
321         if len(sw) != 2:
322             raise ValueError, sw
323         if not (sw[0] == 0x90 and sw[1] == 0x00):
324             raise ValueError, ("%02x%02x" % (sw[0], sw[1]))
325
326     def cmd_external_authenticate(self, signed):
327         cmd_data = iso7816_compose(0x82, 0x00, 0x00, signed[0:128], cls=0x10)
328         sw = self.icc_send_cmd(cmd_data)
329         if len(sw) != 2:
330             raise ValueError, sw
331         if not (sw[0] == 0x90 and sw[1] == 0x00):
332             raise ValueError, ("%02x%02x" % (sw[0], sw[1]))
333         cmd_data = iso7816_compose(0x82, 0x00, 0x00, signed[128:])
334         sw = self.icc_send_cmd(cmd_data)
335         if len(sw) != 2:
336             raise ValueError, sw
337         if not (sw[0] == 0x90 and sw[1] == 0x00):
338             raise ValueError, ("%02x%02x" % (sw[0], sw[1]))
339
340     def cmd_get_challenge(self):
341         cmd_data = iso7816_compose(0x84, 0x00, 0x00, '')
342         sw = self.icc_send_cmd(cmd_data)
343         if len(sw) != 2:
344             raise ValueError, sw
345         if sw[0] != 0x61:
346             raise ValueError, ("%02x%02x" % (sw[0], sw[1]))
347         return self.cmd_get_response(sw[1])
348
349 def compare(data_original, data_in_device):
350     i = 0 
351     for d in data_original:
352         if ord(d) != data_in_device[i]:
353             raise ValueError, "verify failed at %08x" % i
354         i += 1
355
356 def ccid_devices():
357     busses = usb.busses()
358     for bus in busses:
359         devices = bus.devices
360         for dev in devices:
361             for config in dev.configurations:
362                 for intf in config.interfaces:
363                     for alt in intf:
364                         if alt.interfaceClass == CCID_CLASS and \
365                                 alt.interfaceSubClass == CCID_SUBCLASS and \
366                                 alt.interfaceProtocol == CCID_PROTOCOL_0:
367                             yield dev, config, alt
368
369 USB_VENDOR_FSIJ=0x234b
370 USB_PRODUCT_GNUK=0x0000
371
372 def gnuk_devices():
373     busses = usb.busses()
374     for bus in busses:
375         devices = bus.devices
376         for dev in devices:
377             if dev.idVendor != USB_VENDOR_FSIJ:
378                 continue
379             if dev.idProduct != USB_PRODUCT_GNUK:
380                 continue
381             yield dev
382
383 def to_string(t):
384     result = ""
385     for c in t:
386         result += chr(c)
387     return result
388
389 from subprocess import check_output
390
391 SHA256_OID_PREFIX="3031300d060960864801650304020105000420"
392
393 # When user specify KEYGRIP, use it.  Or else, connect to SCD directly.
394 def gpg_sign(keygrip, hash):
395     if keygrip:
396         result = check_output(["gpg-connect-agent",
397                                "SIGKEY %s" % keygrip,
398                                "SETHASH --hash=sha256 %s" % hash,
399                                "PKSIGN --hash=sha256", "/bye"])
400     else:
401         result = check_output(["gpg-connect-agent",
402                                "SCD SETDATA " + SHA256_OID_PREFIX + hash,
403                                "SCD PKAUTH OPENPGP.3",
404                                "/bye"])
405     signed = ""
406     while True:
407         i = result.find('%')
408         if i < 0:
409             signed += result
410             break
411         hex_str = result[i+1:i+3]
412         signed += result[0:i]
413         signed += chr(int(hex_str,16))
414         result = result[i+3:]
415
416     if keygrip:
417         pos = signed.index("D (7:sig-val(3:rsa(1:s256:") + 26
418         signed = signed[pos:-7]
419     else:
420         pos = signed.index("D ") + 2
421         signed = signed[pos:-4]     # \nOK\n
422     if len(signed) != 256:
423         raise ValueError, binascii.hexlify(signed)
424     return signed
425
426 def UNSIGNED(n):
427     return n & 0xffffffff
428
429 def crc32(bytestr):
430     crc = binascii.crc32(bytestr)
431     return UNSIGNED(crc)
432
433 def main(keygrip, data_regnual, data_upgrade):
434     l = len(data_regnual)
435     if (l & 0x03) != 0:
436         data_regnual = data_regnual.ljust(l + 4 - (l & 0x03), chr(0))
437     crc32code = crc32(data_regnual)
438     print "CRC32: %04x\n" % crc32code
439     data_regnual += pack('<I', crc32code)
440     for (dev, config, intf) in ccid_devices():
441         try:
442             icc = gnuk_token(dev, config, intf)
443             print "Device: ", dev.filename
444             print "Configuration: ", config.value
445             print "Interface: ", intf.interfaceNumber
446             break
447         except:
448             icc = None
449     if icc.icc_get_status() == 2:
450         raise ValueError, "No ICC present"
451     elif icc.icc_get_status() == 1:
452         icc.icc_power_on()
453     icc.cmd_select_openpgp()
454     challenge = icc.cmd_get_challenge()
455     signed = gpg_sign(keygrip, binascii.hexlify(to_string(challenge)))
456     icc.cmd_external_authenticate(signed)
457     icc.stop_gnuk()
458     mem_info = icc.mem_info()
459     print "%08x:%08x" % mem_info
460     print "Downloading flash upgrade program..."
461     icc.download(mem_info[0], data_regnual)
462     print "Run flash upgrade program..."
463     icc.execute(mem_info[0] + len(data_regnual) - 4)
464     #
465     time.sleep(3)
466     icc.reset_device()
467     del icc
468     icc = None
469     #
470     print "Wait 3 seconds..."
471     time.sleep(3)
472     # Then, send upgrade program...
473     reg = None
474     for dev in gnuk_devices():
475         try:
476             reg = regnual(dev)
477             print "Device: ", dev.filename
478             break
479         except:
480             pass
481     mem_info = reg.mem_info()
482     print "%08x:%08x" % mem_info
483     print "Downloading the program"
484     reg.download(mem_info[0], data_upgrade)
485     reg.protect()
486     reg.finish()
487     reg.reset_device()
488     return 0
489
490
491 if __name__ == '__main__':
492     keygrip = None
493     if sys.argv[1] == '-k':
494         sys.argv.pop(1)
495         keygrip = sys.argv[1]
496         sys.argv.pop(1)
497     filename_regnual = sys.argv[1]
498     filename_upgrade = sys.argv[2]
499     f = open(filename_regnual)
500     data_regnual = f.read()
501     f.close()
502     print "%s: %d" % (filename_regnual, len(data_regnual))
503     f = open(filename_upgrade)
504     data_upgrade = f.read()
505     f.close()
506     print "%s: %d" % (filename_upgrade, len(data_upgrade))
507     main(keygrip, data_regnual, data_upgrade[4096:])