channel selection
[gnuk/neug.git] / tool / neug_upgrade.py
1 #! /usr/bin/python
2
3 """
4 neug_upgrade.py - a tool to upgrade firmware of Gnuk Token / NeuG device
5
6 Copyright (C) 2012 Free Software Initiative of Japan
7 Author: NIIBE Yutaka <gniibe@fsij.org>
8
9 This file is a part of NeuG, a TRNG implementation.
10
11 Gnuk is free software: you can redistribute it and/or modify it
12 under the terms of the GNU General Public License as published by
13 the Free Software Foundation, either version 3 of the License, or
14 (at your option) any later version.
15
16 Gnuk is distributed in the hope that it will be useful, but WITHOUT
17 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
18 or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
19 License for more details.
20
21 You should have received a copy of the GNU General Public License
22 along with this program.  If not, see <http://www.gnu.org/licenses/>.
23 """
24
25 from struct import *
26 import sys, time, os, binascii, string
27
28 # INPUT: binary file
29
30 # Assume only single NeuG device is attached to computer
31
32 import usb
33
# USB interface class, subclass and protocol that identify a NeuG
# interface (checked by com_devices() and by the neug constructor).
COM_CLASS, COM_SUBCLASS, COM_PROTOCOL_0 = 0x0A, 0x00, 0x00
38
class regnual(object):
    """
    Handle on the flash upgrade program ("regnual") once it has been
    started on the device and the device has re-enumerated (see main).

    Every operation is a vendor control transfer: requestType 0xc0 reads
    from the device, 0x40 writes to it.
    """

    BLOCK_SIZE = 256            # bytes written and flashed per request
    FLASH_BASE = 0x08000000     # address that maps to block index 0

    def __init__(self, dev):
        """
        __init__(dev) -> None
        Open DEV and claim the first alternate setting of its first interface.
        dev: usb.Device object.
        Raises ValueError if that interface is not vendor specific (0xff).
        """
        conf = dev.configurations[0]
        intf_alt = conf.interfaces[0]
        intf = intf_alt[0]
        if intf.interfaceClass != 0xff:
            raise ValueError("Wrong interface class")
        self.__devhandle = dev.open()
        try:
            self.__devhandle.setConfiguration(conf)
        except Exception:
            # Best effort, as in the original tool.
            # NOTE(review): presumably this fails harmlessly when the
            # configuration is already selected — confirm.
            pass
        self.__devhandle.claimInterface(intf)
        self.__devhandle.setAltInterface(intf)

    @staticmethod
    def _le32(buf, offset=0):
        """Decode the little-endian 32-bit value stored at BUF[OFFSET:OFFSET+4]."""
        return (((buf[offset + 3] * 256 + buf[offset + 2]) * 256
                 + buf[offset + 1]) * 256 + buf[offset])

    def _read_status(self):
        """Read the 4-byte result word of the previous operation (request 2)."""
        res = self.__devhandle.controlMsg(requestType=0xc0, request=2,
                                          value=0, index=0, buffer=4,
                                          timeout=10000)
        return self._le32(res)

    def _write_block(self, block_index, chunk, crc_input):
        """
        Upload CHUNK (request 1), verify its CRC, then flash it as block
        BLOCK_INDEX (request 3) and check the result.

        crc_input: the bytes whose CRC32 the device is expected to report
        (CHUNK itself, or CHUNK padded with 0xff for a short last block).
        Problems are reported by printing "failure"; the transfer goes on.
        """
        self.__devhandle.controlMsg(requestType=0x40, request=1,
                                    value=0, index=0, buffer=chunk,
                                    timeout=10000)
        crc32code = crc32(crc_input)
        # The check passes only when the device returns the bitwise
        # complement of the CRC32 computed here.
        if (crc32code ^ self._read_status()) != 0xffffffff:
            print("failure")
        self.__devhandle.controlMsg(requestType=0x40, request=3,
                                    value=block_index, index=0,
                                    buffer=None, timeout=10000)
        time.sleep(0.010)
        # A zero result word is treated as a failed flash operation.
        if self._read_status() == 0:
            print("failure")

    def mem_info(self):
        """
        mem_info() -> (start, end)
        Return the two 32-bit addresses reported by request 0.
        """
        mem = self.__devhandle.controlMsg(requestType=0xc0, request=0,
                                          value=0, index=0, buffer=8,
                                          timeout=10000)
        return (self._le32(mem, 0), self._le32(mem, 4))

    def download(self, start, data):
        """
        download(start, data) -> None
        Write DATA to flash starting at START in BLOCK_SIZE blocks.

        START is expected to be BLOCK_SIZE aligned.  The number of full
        blocks is len(DATA) // BLOCK_SIZE.  A trailing partial block is sent
        as is, and its CRC is checked against a 0xff-padded copy.
        """
        size = self.BLOCK_SIZE
        addr = start
        addr_end = (start + len(data)) & 0xffffff00
        i = (addr - self.FLASH_BASE) // size
        print("start %08x" % addr)
        print("end   %08x" % addr_end)
        full_blocks, residue = divmod(len(data), size)
        for j in range(full_blocks):
            print("# %08x: %d: %d : %d" % (addr, i, j, size))
            chunk = data[j * size:(j + 1) * size]
            self._write_block(i, chunk, chunk)
            i += 1
            addr += size
        if residue != 0:
            print("# %08x: %d : %d" % (addr, i, residue))
            chunk = data[full_blocks * size:]
            self._write_block(i, chunk, chunk.ljust(size, b'\xff'))

    def protect(self):
        """Send request 4 and report if the resulting status word is zero."""
        self.__devhandle.controlMsg(requestType=0x40, request=4,
                                    value=0, index=0, buffer=None,
                                    timeout=10000)
        time.sleep(0.100)
        if self._read_status() == 0:
            print("protection failure")

    def finish(self):
        """Send request 5 to tell the flash upgrade program we are done."""
        self.__devhandle.controlMsg(requestType=0x40, request=5,
                                    value=0, index=0, buffer=None,
                                    timeout=10000)

    def reset_device(self):
        """Reset the USB device; errors are ignored (best effort)."""
        try:
            self.__devhandle.reset()
        except Exception:
            pass
144
class neug(object):
    """
    Handle on a running NeuG device, opened through its COM_CLASS interface.

    It is used to stop the device, to load a program into it block by
    block and to start that program (see main).
    """

    BLOCK_SIZE = 256            # bytes sent per control transfer
    RAM_BASE = 0x20000000       # address that maps to block index 0

    def __init__(self, device, configuration, interface):
        """
        __init__(device, configuration, interface) -> None
        Initialize the device.
        device: usb.Device object.
        configuration: configuration object the interface belongs to
                       (as yielded by com_devices()).
        interface: usb.Interface object representing the interface and alternate setting.
        Raises ValueError if the interface class or subclass does not match.
        """
        if interface.interfaceClass != COM_CLASS:
            raise ValueError("Wrong interface class")
        if interface.interfaceSubClass != COM_SUBCLASS:
            raise ValueError("Wrong interface sub class")
        self.__devhandle = device.open()
        # NOTE(review): the interface is deliberately not claimed (the
        # original detachKernelDriver/claimInterface/setAltInterface calls
        # were commented out); every request below is a control transfer.

        self.__intf = interface.interfaceNumber
        self.__alt = interface.alternateSetting
        self.__conf = configuration

        # NOTE(review): not used by the methods below, which pass their
        # own (much shorter) timeouts.
        self.__timeout = 10000

    @staticmethod
    def _le32(buf, offset=0):
        """Decode the little-endian 32-bit value stored at BUF[OFFSET:OFFSET+4]."""
        return (((buf[offset + 3] * 256 + buf[offset + 2]) * 256
                 + buf[offset + 1]) * 256 + buf[offset])

    def reset_device(self):
        """Reset the USB device; errors are ignored (best effort)."""
        try:
            self.__devhandle.reset()
        except Exception:
            pass

    def stop_neug(self):
        """Ask the device to stop by sending vendor request 255."""
        self.__devhandle.controlMsg(requestType=0x40, request=255,
                                    value=0, index=0, buffer=None,
                                    timeout=10)

    def mem_info(self):
        """
        mem_info() -> (start, end)
        Return the two 32-bit addresses reported by request 0.
        """
        mem = self.__devhandle.controlMsg(requestType=0xc0, request=0,
                                          value=0, index=0, buffer=8,
                                          timeout=10)
        return (self._le32(mem, 0), self._le32(mem, 4))

    def _send_chunk(self, block_index, chunk):
        """Send one chunk with request 1, tagged with BLOCK_INDEX in wValue."""
        self.__devhandle.controlMsg(requestType=0x40, request=1,
                                    value=block_index, index=0,
                                    buffer=chunk, timeout=10)

    def download(self, start, data):
        """
        download(start, data) -> None
        Send DATA to the device in BLOCK_SIZE chunks.  Each chunk carries its
        block index, counted from RAM_BASE.

        START is expected to be BLOCK_SIZE aligned.  The number of full
        chunks is len(DATA) // BLOCK_SIZE.  Deriving it from a masked end
        address instead produced short or empty transfers for an
        unaligned START.
        """
        size = self.BLOCK_SIZE
        addr = start
        addr_end = (start + len(data)) & 0xffffff00
        i = (addr - self.RAM_BASE) // size
        print("start %08x" % addr)
        print("end   %08x" % addr_end)
        full_blocks, residue = divmod(len(data), size)
        for j in range(full_blocks):
            print("# %08x: %d : %d" % (addr, i, size))
            self._send_chunk(i, data[j * size:(j + 1) * size])
            i += 1
            addr += size
        if residue != 0:
            print("# %08x: %d : %d" % (addr, i, residue))
            self._send_chunk(i, data[full_blocks * size:])

    def execute(self, last_addr):
        """
        Send request 2 with LAST_ADDR, relative to RAM_BASE, split into a
        block index (wValue) and an offset within that block (wIndex).
        """
        i, o = divmod(last_addr - self.RAM_BASE, self.BLOCK_SIZE)
        self.__devhandle.controlMsg(requestType=0x40, request=2,
                                    value=i, index=o, buffer=None,
                                    timeout=10)
221
def compare(data_original, data_in_device):
    """
    Raise ValueError unless DATA_IN_DEVICE begins with DATA_ORIGINAL.

    data_original: byte string (Python 2 str; bytes or str on Python 3).
    data_in_device: indexable sequence of integer byte values.
    The error message carries the first mismatching (or missing) offset.
    Extra trailing bytes in DATA_IN_DEVICE are accepted.
    """
    device_len = len(data_in_device)
    for i, d in enumerate(data_original):
        # Python 3 bytes yield ints, Python 2 str yields 1-char strings.
        expected = d if isinstance(d, int) else ord(d)
        if i >= device_len or expected != data_in_device[i]:
            raise ValueError("verify failed at %08x" % i)
228
def com_devices():
    """
    Generate (device, configuration, alt_setting) triples for every
    interface alternate setting whose class, subclass and protocol equal
    COM_CLASS, COM_SUBCLASS and COM_PROTOCOL_0.
    """
    wanted = (COM_CLASS, COM_SUBCLASS, COM_PROTOCOL_0)
    for bus in usb.busses():
        for device in bus.devices:
            for configuration in device.configurations:
                for interface in configuration.interfaces:
                    for setting in interface:
                        found = (setting.interfaceClass,
                                 setting.interfaceSubClass,
                                 setting.interfaceProtocol)
                        if found == wanted:
                            yield device, configuration, setting
241
# Vendor/product ID pair that gnuk_devices() searches for.
USB_VENDOR_FSIJ, USB_PRODUCT_GNUK = 0x234B, 0x0000
244
def gnuk_devices():
    """
    Generate every USB device whose vendor ID is USB_VENDOR_FSIJ and
    whose product ID is USB_PRODUCT_GNUK.
    """
    wanted = (USB_VENDOR_FSIJ, USB_PRODUCT_GNUK)
    for bus in usb.busses():
        for device in bus.devices:
            if (device.idVendor, device.idProduct) == wanted:
                yield device
255
def to_string(t):
    """
    Convert a sequence of character codes (e.g. byte values) into a string.

    Builds the result with a single join.  The previous repeated `+=`
    concatenation can be quadratic in the length of T.
    """
    return "".join(map(chr, t))
261
def UNSIGNED(n):
    """Return N with everything above its low 32 bits discarded."""
    low32_mask = (1 << 32) - 1
    return n & low32_mask
264
def crc32(bytestr):
    """Return binascii.crc32 of BYTESTR as an unsigned 32-bit integer."""
    return UNSIGNED(binascii.crc32(bytestr))
268
def main(data_regnual, data_upgrade):
    """
    Upgrade the firmware of an attached NeuG device.

    1. Pad DATA_REGNUAL (the flash upgrade program) with zero bytes to a
       multiple of 4 and append its CRC32, packed little endian.
    2. Stop the running NeuG, download DATA_UPGRADE's writer into it and
       start it, then reset the device.
    3. Find the re-enumerated device running the upgrade program, write
       DATA_UPGRADE with it, protect, finish and reset.

    Returns 0.  Raises ValueError when no suitable device is found at
    either stage.
    """
    l = len(data_regnual)
    if (l & 0x03) != 0:
        data_regnual = data_regnual.ljust(l + 4 - (l & 0x03), b'\x00')
    crc32code = crc32(data_regnual)
    # A CRC32 is a 32-bit value: print all 8 hex digits.
    print("CRC32: %08x\n" % crc32code)
    data_regnual += pack('<I', crc32code)

    com = None
    for (dev, config, intf) in com_devices():
        try:
            com = neug(dev, config, intf)
        except Exception:
            # Not usable (wrong class or failed to open): try the next one.
            continue
        print("Device:  %s" % dev.filename)
        print("Configuration:  %s" % config.value)
        print("Interface:  %s" % intf.interfaceNumber)
        break
    if com is None:
        raise ValueError("No NeuG Device Present")
    com.stop_neug()
    time.sleep(0.500)
    mem_info = com.mem_info()
    print("%08x:%08x" % mem_info)
    print("Downloading flash upgrade program...")
    com.download(mem_info[0], data_regnual)
    print("Run flash upgrade program...")
    # The address passed is that of the appended CRC word (last 4 bytes).
    com.execute(mem_info[0] + len(data_regnual) - 4)
    time.sleep(3)
    com.reset_device()
    com = None          # drop the handle on the old device

    print("Wait 3 seconds...")
    time.sleep(3)
    # Then, send the upgrade image through the flash upgrade program.
    reg = None
    for dev in gnuk_devices():
        try:
            reg = regnual(dev)
        except Exception:
            continue
        print("Device:  %s" % dev.filename)
        break
    if reg is None:
        # Previously this fell through to reg.mem_info() on None.
        raise ValueError("No regnual Device Present")
    mem_info = reg.mem_info()
    print("%08x:%08x" % mem_info)
    print("Downloading the program")
    reg.download(mem_info[0], data_upgrade)
    reg.protect()
    reg.finish()
    reg.reset_device()
    return 0
321
322
if __name__ == '__main__':
    if len(sys.argv) < 3:
        sys.stderr.write("Usage: %s REGNUAL_BINARY UPGRADE_BINARY\n"
                         % sys.argv[0])
        sys.exit(1)
    filename_regnual = sys.argv[1]
    filename_upgrade = sys.argv[2]
    # Firmware images are binary: open in 'rb' so no newline translation
    # corrupts them, and let 'with' close the files.
    with open(filename_regnual, 'rb') as f:
        data_regnual = f.read()
    print("%s: %d" % (filename_regnual, len(data_regnual)))
    with open(filename_upgrade, 'rb') as f:
        data_upgrade = f.read()
    print("%s: %d" % (filename_upgrade, len(data_upgrade)))
    # The first 4096 bytes of the upgrade image are not sent.
    # NOTE(review): presumably a region that must not be rewritten
    # (e.g. a loader) — confirm against the firmware layout.
    main(data_regnual, data_upgrade[4096:])