more upgrade
[gnuk/gnuk.git] / tool / gnuk_upgrade.py
1 #! /usr/bin/python
2
3 """
4 gnuk_upgrade.py - a tool to upgrade firmware of Gnuk Token
5
6 Copyright (C) 2012 Free Software Initiative of Japan
7 Author: NIIBE Yutaka <gniibe@fsij.org>
8
9 This file is a part of Gnuk, a GnuPG USB Token implementation.
10
11 Gnuk is free software: you can redistribute it and/or modify it
12 under the terms of the GNU General Public License as published by
13 the Free Software Foundation, either version 3 of the License, or
14 (at your option) any later version.
15
16 Gnuk is distributed in the hope that it will be useful, but WITHOUT
17 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
18 or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
19 License for more details.
20
21 You should have received a copy of the GNU General Public License
22 along with this program.  If not, see <http://www.gnu.org/licenses/>.
23 """
24
25 from intel_hex import *
26 from struct import *
27 import sys, time, os, binascii, string
28
29 # INPUT: binary file
30
31 # Assume only single CCID device is attached to computer, and it's Gnuk Token
32
33 import usb
34
35 # USB class, subclass, protocol
36 CCID_CLASS = 0x0B
37 CCID_SUBCLASS = 0x00
38 CCID_PROTOCOL_0 = 0x00
39
40 def icc_compose(msg_type, data_len, slot, seq, param, data):
41     return pack('<BiBBBH', msg_type, data_len, slot, seq, 0, param) + data
42
43 def iso7816_compose(ins, p1, p2, data):
44     cls = 0x00 
45     data_len = len(data)
46     if data_len == 0:
47         return pack('>BBBB', cls, ins, p1, p2)
48     else:
49         return pack('>BBBBB', cls, ins, p1, p2, data_len) + data
50
51 # This class only supports Gnuk (for now) 
52 class gnuk_token:
53     def __init__(self, device, configuration, interface):
54         """
55         __init__(device, configuration, interface) -> None
56         Initialize the device.
57         device: usb.Device object.
58         configuration: configuration number.
59         interface: usb.Interface object representing the interface and altenate setting.
60         """
61         if interface.interfaceClass != CCID_CLASS:
62             raise ValueError, "Wrong interface class"
63         if interface.interfaceSubClass != CCID_SUBCLASS:
64             raise ValueError, "Wrong interface sub class"
65         self.__devhandle = device.open()
66         try:
67             self.__devhandle.setConfiguration(configuration)
68         except:
69             pass
70         self.__devhandle.claimInterface(interface)
71         self.__devhandle.setAltInterface(interface)
72
73         self.__intf = interface.interfaceNumber
74         self.__alt = interface.alternateSetting
75         self.__conf = configuration
76
77         self.__bulkout = 1
78         self.__bulkin  = 0x81
79
80         self.__timeout = 10000
81         self.__seq = 0
82
83     def stop_icc(self):
84         # XXX: need to disclaim interface and close device and re-open???
85         # self.__devhandle.setConfiguration(0)
86         return
87
88     def mem_info(self):
89         mem = self.__devhandle.controlMsg(requestType = 0xc0, request = 0,
90                                           value = 0, index = 0, buffer = 8,
91                                           timeout = 10)
92         start = ((mem[3]*256 + mem[2])*256 + mem[1])*256 + mem[0]
93         end = ((mem[7]*256 + mem[6])*256 + mem[5])*256 + mem[4]
94         return (start, end)
95
96     def download(self, start, data):
97         addr = start
98         addr_end = start + len(data)
99         i = (addr - 0x20000000) / 0x100
100         print "start %08x" % addr
101         print "end   %08x" % addr_end
102         while addr < addr_end:
103             print "# %08x: %d : %d" % (addr, i, 256)
104             self.__devhandle.controlMsg(requestType = 0x40, request = 1,
105                                         value = i, index = 0,
106                                         buffer = data[i*256:i*256+256],
107                                         timeout = 10)
108             i = i+1
109             addr = addr + 256
110         residue = len(data) % 256
111         if residue != 0:
112             print "# %08x: %d : %d" % (addr, i, residue)
113             self.__devhandle.controlMsg(requestType = 0x40, request = 1,
114                                         value = i, index = 0,
115                                         buffer = data[i*256:],
116                                         timeout = 10)
117
118     def execute(self):
119         self.__devhandle.controlMsg(requestType = 0x40, request = 2,
120                                     value = 0, index = 0, buffer = None,
121                                     timeout = 10)
122
123     def icc_get_result(self):
124         msg = self.__devhandle.bulkRead(self.__bulkin, 1024, self.__timeout)
125         if len(msg) < 10:
126             raise ValueError, "icc_get_result"
127         msg_type = msg[0]
128         data_len = msg[1] + (msg[2]<<8) + (msg[3]<<16) + (msg[4]<<24)
129         slot = msg[5]
130         seq = msg[6]
131         status = msg[7]
132         error = msg[8]
133         chain = msg[9]
134         data = msg[10:]
135         # XXX: check msg_type, data_len, slot, seq, error
136         return (status, chain, data)
137
138     def icc_get_status(self):
139         msg = icc_compose(0x65, 0, 0, self.__seq, 0, "")
140         self.__devhandle.bulkWrite(self.__bulkout, msg, self.__timeout)
141         self.__seq += 1
142         status, chain, data = self.icc_get_result()
143         # XXX: check chain, data
144         return status
145
146     def icc_power_on(self):
147         msg = icc_compose(0x62, 0, 0, self.__seq, 0, "")
148         self.__devhandle.bulkWrite(self.__bulkout, msg, self.__timeout)
149         self.__seq += 1
150         status, chain, data = self.icc_get_result()
151         # XXX: check status, chain
152         return data             # ATR
153
154     def icc_power_off(self):
155         msg = icc_compose(0x63, 0, 0, self.__seq, 0, "")
156         self.__devhandle.bulkWrite(self.__bulkout, msg, self.__timeout)
157         self.__seq += 1
158         status, chain, data = self.icc_get_result()
159         # XXX: check chain, data
160         return status
161
162     def icc_send_data_block(self, data):
163         msg = icc_compose(0x6f, len(data), 0, self.__seq, 0, data)
164         self.__devhandle.bulkWrite(self.__bulkout, msg, self.__timeout)
165         self.__seq += 1
166         return self.icc_get_result()
167
168     def icc_send_cmd(self, data):
169         status, chain, data_rcv = self.icc_send_data_block(data)
170         if chain == 0:
171             return data_rcv
172         elif chain == 1:
173             d = data_rcv
174             while True:
175                 msg = icc_compose(0x6f, 0, 0, self.__seq, 0x10, "")
176                 self.__devhandle.bulkWrite(self.__bulkout, msg, self.__timeout)
177                 self.__seq += 1
178                 status, chain, data_rcv = self.icc_get_result()
179                 # XXX: check status
180                 d += data_rcv
181                 if chain == 2:
182                     break
183                 elif chain == 3:
184                     continue
185                 else:
186                     raise ValueError, "icc_send_cmd chain"
187             return d
188         else:
189             raise ValueError, "icc_send_cmd"
190
191     def cmd_get_response(self, expected_len):
192         cmd_data = iso7816_compose(0xc0, 0x00, 0x00, '') + pack('>B', expected_len)
193         response = self.icc_send_cmd(cmd_data)
194         return response[:-2]
195
196     def cmd_verify(self, who, passwd):
197         cmd_data = iso7816_compose(0x20, 0x00, 0x80+who, passwd)
198         sw = self.icc_send_cmd(cmd_data)
199         if len(sw) != 2:
200             raise ValueError, sw
201         if not (sw[0] == 0x90 and sw[1] == 0x00):
202             raise ValueError, sw
203
204     def cmd_select_openpgp(self):
205         cmd_data = iso7816_compose(0xa4, 0x04, 0x0c, "\xD2\x76\x00\x01\x24\x01")
206         sw = self.icc_send_cmd(cmd_data)
207         if len(sw) != 2:
208             raise ValueError, sw
209         if not (sw[0] == 0x90 and sw[1] == 0x00):
210             raise ValueError, ("%02x%02x" % (sw[0], sw[1]))
211
212     def cmd_external_authenticate(self, signed):
213         cmd_data = iso7816_compose(0x82, 0x00, 0x00, signed)
214         sw = self.icc_send_cmd(cmd_data)
215         if len(sw) != 2:
216             raise ValueError, sw
217         if not (sw[0] == 0x90 and sw[1] == 0x00):
218             raise ValueError, ("%02x%02x" % (sw[0], sw[1]))
219
220     def cmd_get_challenge(self):
221         cmd_data = iso7816_compose(0x84, 0x00, 0x00, '')
222         sw = self.icc_send_cmd(cmd_data)
223         if len(sw) != 2:
224             raise ValueError, sw
225         if sw[0] != 0x61:
226             raise ValueError, ("%02x%02x" % (sw[0], sw[1]))
227         return self.cmd_get_response(sw[1])
228
229 def compare(data_original, data_in_device):
230     i = 0 
231     for d in data_original:
232         if ord(d) != data_in_device[i]:
233             raise ValueError, "verify failed at %08x" % i
234         i += 1
235
236 def get_device():
237     busses = usb.busses()
238     for bus in busses:
239         devices = bus.devices
240         for dev in devices:
241             for config in dev.configurations:
242                 for intf in config.interfaces:
243                     for alt in intf:
244                         if alt.interfaceClass == CCID_CLASS and \
245                                 alt.interfaceSubClass == CCID_SUBCLASS and \
246                                 alt.interfaceProtocol == CCID_PROTOCOL_0:
247                             return dev, config, alt
248     raise ValueError, "Device not found"
249
250 def to_string(t):
251     result = ""
252     for c in t:
253         result += chr(c)
254     return result
255
256 def main(passwd, data_regnual, data_upgrade):
257     dev, config, intf = get_device()
258     print "Device: ", dev.filename
259     print "Configuration: ", config.value
260     print "Interface: ", intf.interfaceNumber
261     icc = gnuk_token(dev, config, intf)
262     if icc.icc_get_status() == 2:
263         raise ValueError, "No ICC present"
264     elif icc.icc_get_status() == 1:
265         icc.icc_power_on()
266     icc.cmd_verify(3, passwd)
267     icc.cmd_select_openpgp()
268     challenge = icc.cmd_get_challenge()
269     signed = to_string(challenge)
270     icc.cmd_external_authenticate(signed)
271     icc.stop_icc() # disable all interfaces but control pipe
272     mem_info = icc.mem_info()
273     print "%08x:%08x" % mem_info
274     print "Downloading flash upgrade program..."
275     icc.download(mem_info[0], data_regnual)
276     print "Run flash upgrade program..."
277     icc.execute()
278     # Then, send upgrade program...
279     print "Downloading the program"
280     return 0
281
282 DEFAULT_PW3 = "12345678"
283
284 if __name__ == '__main__':
285     passwd = DEFAULT_PW3
286     if sys.argv[1] == '-p':
287         from getpass import getpass
288         passwd = getpass("Admin password: ")
289         sys.argv.pop(1)
290     filename_regnual = sys.argv[1]
291     filename_upgrade = sys.argv[2]
292     f = open(filename_regnual)
293     data_regnual = f.read()
294     f.close()
295     print "%s: %d" % (filename_regnual, len(data_regnual))
296     f = open(filename_upgrade)
297     data_upgrade = f.read()
298     f.close()
299     print "%s: %d" % (filename_upgrade, len(data_upgrade))
300     main(passwd, data_regnual, data_upgrade)