682d877e891fea6ab5273f419d3b9178a1847d8d
[gnuk/gnuk.git] / tool / gnuk_upgrade.py
1 #! /usr/bin/python
2
3 """
4 gnuk_upgrade.py - a tool to upgrade firmware of Gnuk Token
5
6 Copyright (C) 2012 Free Software Initiative of Japan
7 Author: NIIBE Yutaka <gniibe@fsij.org>
8
9 This file is a part of Gnuk, a GnuPG USB Token implementation.
10
11 Gnuk is free software: you can redistribute it and/or modify it
12 under the terms of the GNU General Public License as published by
13 the Free Software Foundation, either version 3 of the License, or
14 (at your option) any later version.
15
16 Gnuk is distributed in the hope that it will be useful, but WITHOUT
17 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
18 or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
19 License for more details.
20
21 You should have received a copy of the GNU General Public License
22 along with this program.  If not, see <http://www.gnu.org/licenses/>.
23 """
24
25 from intel_hex import *
26 from struct import *
27 import sys, time, os, binascii, string
28
29 # INPUT: binary file
30
31 # Assume only single CCID device is attached to computer, and it's Gnuk Token
32
33 import usb
34
35 # USB class, subclass, protocol
36 CCID_CLASS = 0x0B
37 CCID_SUBCLASS = 0x00
38 CCID_PROTOCOL_0 = 0x00
39
40 def icc_compose(msg_type, data_len, slot, seq, param, data):
41     return pack('<BiBBBH', msg_type, data_len, slot, seq, 0, param) + data
42
43 def iso7816_compose(ins, p1, p2, data):
44     cls = 0x00 
45     data_len = len(data)
46     if data_len == 0:
47         return pack('>BBBB', cls, ins, p1, p2)
48     else:
49         return pack('>BBBBB', cls, ins, p1, p2, data_len) + data
50
51 # This class only supports Gnuk (for now) 
52 class gnuk_token:
53     def __init__(self, device, configuration, interface):
54         """
55         __init__(device, configuration, interface) -> None
56         Initialize the device.
57         device: usb.Device object.
58         configuration: configuration number.
59         interface: usb.Interface object representing the interface and altenate setting.
60         """
61         if interface.interfaceClass != CCID_CLASS:
62             raise ValueError, "Wrong interface class"
63         if interface.interfaceSubClass != CCID_SUBCLASS:
64             raise ValueError, "Wrong interface sub class"
65         self.__devhandle = device.open()
66         try:
67             self.__devhandle.setConfiguration(configuration)
68         except:
69             pass
70         self.__devhandle.claimInterface(interface)
71         self.__devhandle.setAltInterface(interface)
72
73         self.__intf = interface.interfaceNumber
74         self.__alt = interface.alternateSetting
75         self.__conf = configuration
76
77         self.__bulkout = 1
78         self.__bulkin  = 0x81
79
80         self.__timeout = 10000
81         self.__seq = 0
82
83     def stop_icc(self):
84         # XXX: need to disclaim interface and close device and re-open???
85         # self.__devhandle.setConfiguration(0)
86         return
87
88     def mem_info(self):
89         mem = self.__devhandle.controlMsg(requestType = 0xc0, request = 0,
90                                           value = 0, index = 0, buffer = 8,
91                                           timeout = 10)
92         start = ((mem[3]*256 + mem[2])*256 + mem[1])*256 + mem[0]
93         end = ((mem[7]*256 + mem[6])*256 + mem[5])*256 + mem[4]
94         return (start, end)
95
96     def icc_get_result(self):
97         msg = self.__devhandle.bulkRead(self.__bulkin, 1024, self.__timeout)
98         if len(msg) < 10:
99             raise ValueError, "icc_get_result"
100         msg_type = msg[0]
101         data_len = msg[1] + (msg[2]<<8) + (msg[3]<<16) + (msg[4]<<24)
102         slot = msg[5]
103         seq = msg[6]
104         status = msg[7]
105         error = msg[8]
106         chain = msg[9]
107         data = msg[10:]
108         # XXX: check msg_type, data_len, slot, seq, error
109         return (status, chain, data)
110
111     def icc_get_status(self):
112         msg = icc_compose(0x65, 0, 0, self.__seq, 0, "")
113         self.__devhandle.bulkWrite(self.__bulkout, msg, self.__timeout)
114         self.__seq += 1
115         status, chain, data = self.icc_get_result()
116         # XXX: check chain, data
117         return status
118
119     def icc_power_on(self):
120         msg = icc_compose(0x62, 0, 0, self.__seq, 0, "")
121         self.__devhandle.bulkWrite(self.__bulkout, msg, self.__timeout)
122         self.__seq += 1
123         status, chain, data = self.icc_get_result()
124         # XXX: check status, chain
125         return data             # ATR
126
127     def icc_power_off(self):
128         msg = icc_compose(0x63, 0, 0, self.__seq, 0, "")
129         self.__devhandle.bulkWrite(self.__bulkout, msg, self.__timeout)
130         self.__seq += 1
131         status, chain, data = self.icc_get_result()
132         # XXX: check chain, data
133         return status
134
135     def icc_send_data_block(self, data):
136         msg = icc_compose(0x6f, len(data), 0, self.__seq, 0, data)
137         self.__devhandle.bulkWrite(self.__bulkout, msg, self.__timeout)
138         self.__seq += 1
139         return self.icc_get_result()
140
141     def icc_send_cmd(self, data):
142         status, chain, data_rcv = self.icc_send_data_block(data)
143         if chain == 0:
144             return data_rcv
145         elif chain == 1:
146             d = data_rcv
147             while True:
148                 msg = icc_compose(0x6f, 0, 0, self.__seq, 0x10, "")
149                 self.__devhandle.bulkWrite(self.__bulkout, msg, self.__timeout)
150                 self.__seq += 1
151                 status, chain, data_rcv = self.icc_get_result()
152                 # XXX: check status
153                 d += data_rcv
154                 if chain == 2:
155                     break
156                 elif chain == 3:
157                     continue
158                 else:
159                     raise ValueError, "icc_send_cmd chain"
160             return d
161         else:
162             raise ValueError, "icc_send_cmd"
163
164     def cmd_get_response(self, expected_len):
165         cmd_data = iso7816_compose(0xc0, 0x00, 0x00, '') + pack('>B', expected_len)
166         response = self.icc_send_cmd(cmd_data)
167         return response[:-2]
168
169     def cmd_verify(self, who, passwd):
170         cmd_data = iso7816_compose(0x20, 0x00, 0x80+who, passwd)
171         sw = self.icc_send_cmd(cmd_data)
172         if len(sw) != 2:
173             raise ValueError, sw
174         if not (sw[0] == 0x90 and sw[1] == 0x00):
175             raise ValueError, sw
176
177     def cmd_select_openpgp(self):
178         cmd_data = iso7816_compose(0xa4, 0x04, 0x0c, "\xD2\x76\x00\x01\x24\x01")
179         sw = self.icc_send_cmd(cmd_data)
180         if len(sw) != 2:
181             raise ValueError, sw
182         if not (sw[0] == 0x90 and sw[1] == 0x00):
183             raise ValueError, ("%02x%02x" % (sw[0], sw[1]))
184
185     def cmd_external_authenticate(self, signed):
186         cmd_data = iso7816_compose(0x82, 0x00, 0x00, signed)
187         sw = self.icc_send_cmd(cmd_data)
188         if len(sw) != 2:
189             raise ValueError, sw
190         if not (sw[0] == 0x90 and sw[1] == 0x00):
191             raise ValueError, ("%02x%02x" % (sw[0], sw[1]))
192
193     def cmd_get_challenge(self):
194         cmd_data = iso7816_compose(0x84, 0x00, 0x00, '')
195         sw = self.icc_send_cmd(cmd_data)
196         if len(sw) != 2:
197             raise ValueError, sw
198         if sw[0] != 0x61:
199             raise ValueError, ("%02x%02x" % (sw[0], sw[1]))
200         return self.cmd_get_response(sw[1])
201
202 def compare(data_original, data_in_device):
203     i = 0 
204     for d in data_original:
205         if ord(d) != data_in_device[i]:
206             raise ValueError, "verify failed at %08x" % i
207         i += 1
208
209 def get_device():
210     busses = usb.busses()
211     for bus in busses:
212         devices = bus.devices
213         for dev in devices:
214             for config in dev.configurations:
215                 for intf in config.interfaces:
216                     for alt in intf:
217                         if alt.interfaceClass == CCID_CLASS and \
218                                 alt.interfaceSubClass == CCID_SUBCLASS and \
219                                 alt.interfaceProtocol == CCID_PROTOCOL_0:
220                             return dev, config, alt
221     raise ValueError, "Device not found"
222
223 def to_string(t):
224     result = ""
225     for c in t:
226         result += chr(c)
227     return result
228
229 def main(passwd, data_regnual, data_upgrade):
230     dev, config, intf = get_device()
231     print "Device: ", dev.filename
232     print "Configuration: ", config.value
233     print "Interface: ", intf.interfaceNumber
234     icc = gnuk_token(dev, config, intf)
235     if icc.icc_get_status() == 2:
236         raise ValueError, "No ICC present"
237     elif icc.icc_get_status() == 1:
238         icc.icc_power_on()
239     icc.cmd_verify(3, passwd)
240     icc.cmd_select_openpgp()
241     challenge = icc.cmd_get_challenge()
242     signed = to_string(challenge)
243     icc.cmd_external_authenticate(signed)
244     icc.stop_icc() # disable all interfaces but control pipe
245     mem_info icc.mem_info()
246     print "%08x: %08x" % mem_info
247     # download flash install program
248     # exec
249     # ...
250     print "Downloading flash upgrade program"
251     # Then, send upgrade program...
252     print "Downloading the program"
253     return 0
254
255 DEFAULT_PW3 = "12345678"
256
257 if __name__ == '__main__':
258     passwd = DEFAULT_PW3
259     if sys.argv[1] == '-p':
260         from getpass import getpass
261         passwd = getpass("Admin password: ")
262         sys.argv.pop(1)
263     filename_regnual = sys.argv[1]
264     filename_upgrade = sys.argv[2]
265     f = open(filename_regnual)
266     data_regnual = f.read()
267     f.close()
268     print "%s: %d" % (filename_regnual, len(data_regnual))
269     f = open(filename_upgrade)
270     data_upgrade = f.read()
271     f.close()
272     print "%s: %d" % (filename_upgrade, len(data_upgrade))
273     main(passwd, data_regnual, data_upgrade)