support reGNUal
[gnuk/neug.git] / tool / neug_upgrade.py
1 #! /usr/bin/python
2
3 """
4 neug_upgrade.py - a tool to upgrade firmware of Gnuk Token
5
6 Copyright (C) 2012 Free Software Initiative of Japan
7 Author: NIIBE Yutaka <gniibe@fsij.org>
8
9 This file is a part of NeuG, a TRNG implementation.
10
11 Gnuk is free software: you can redistribute it and/or modify it
12 under the terms of the GNU General Public License as published by
13 the Free Software Foundation, either version 3 of the License, or
14 (at your option) any later version.
15
16 Gnuk is distributed in the hope that it will be useful, but WITHOUT
17 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
18 or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
19 License for more details.
20
21 You should have received a copy of the GNU General Public License
22 along with this program.  If not, see <http://www.gnu.org/licenses/>.
23 """
24
25 from struct import *
26 import sys, time, os, binascii, string
27
# INPUT: two binary files - the reGNUal image and the firmware upgrade image
29
# Assumes that only a single NeuG device is attached to the computer
31
32 import usb
33
# Interface class / subclass / protocol triple that com_devices() searches
# for; neug() additionally re-checks the class and subclass.
COM_CLASS, COM_SUBCLASS, COM_PROTOCOL_0 = 0x0a, 0x00, 0x00
38
class regnual(object):
    """Handle on a device running the reGNUal flash-upgrade program.

    All communication is done with vendor control transfers on the first
    interface of the first configuration:
      request 0 (in)  - memory information (see mem_info)
      request 1 (out) - send one data block (up to 256 bytes)
      request 2 (in)  - read a 4-byte little-endian result word
      request 3 (out) - write the previously sent block at a block index
      request 4 (out) - protect (see protect)
      request 5 (out) - finish
    """
    def __init__(self, dev):
        """Open dev and claim its first interface.

        Raises ValueError if that interface is not of class 0xff.
        """
        conf = dev.configurations[0]
        intf_alt = conf.interfaces[0]
        intf = intf_alt[0]
        if intf.interfaceClass != 0xff:
            raise ValueError("Wrong interface class")
        self.__devhandle = dev.open()
        try:
            self.__devhandle.setConfiguration(conf)
        except Exception:
            # Best effort: a failure here is deliberately ignored.
            pass
        self.__devhandle.claimInterface(intf)
        self.__devhandle.setAltInterface(intf)

    @staticmethod
    def _le32(buf, offset=0):
        """Decode buf[offset:offset+4] as a little-endian unsigned integer."""
        return (((buf[offset+3]*256 + buf[offset+2])*256 + buf[offset+1])*256
                + buf[offset])

    def _read_result(self):
        """Issue vendor request 2 and return its 4-byte result as an int."""
        res = self.__devhandle.controlMsg(requestType = 0xc0, request = 2,
                                          value = 0, index = 0, buffer = 4,
                                          timeout = 10000)
        return self._le32(res)

    def _write_block(self, block_no, chunk, crc_data):
        """Send chunk, verify its CRC, then ask the device to write it.

        block_no: block index passed as the request-3 value.
        chunk: the bytes actually transferred (at most 256).
        crc_data: the bytes the device's CRC is expected to cover.
        Prints "failure" on a CRC mismatch or if the write reports 0;
        the transfer continues either way, as in the original tool.
        """
        self.__devhandle.controlMsg(requestType = 0x40, request = 1,
                                    value = 0, index = 0, buffer = chunk,
                                    timeout = 10000)
        crc32code = crc32(crc_data)
        # A good transfer is signalled by the device returning the bitwise
        # complement of the CRC-32 of the block.
        if (crc32code ^ self._read_result()) != 0xffffffff:
            print("failure")
        self.__devhandle.controlMsg(requestType = 0x40, request = 3,
                                    value = block_no, index = 0,
                                    buffer = None, timeout = 10000)
        time.sleep(0.010)
        if self._read_result() == 0:
            print("failure")

    def mem_info(self):
        """Return (start, end) as two little-endian words from request 0."""
        mem = self.__devhandle.controlMsg(requestType = 0xc0, request = 0,
                                          value = 0, index = 0, buffer = 8,
                                          timeout = 10000)
        return (self._le32(mem, 0), self._le32(mem, 4))

    def download(self, start, data):
        """Write data to the device starting at address start.

        Full 256-byte blocks are written first, followed by any trailing
        partial block.  The block index sent with each write is counted in
        256-byte units from 0x08000000.  (Presumably the flash base; confirm.)
        NOTE(review): start is assumed to be 256-byte aligned.
        """
        addr = start
        addr_end = (start + len(data)) & 0xffffff00
        # Floor division keeps the block index an int on Python 2 and 3.
        i = (addr - 0x08000000) // 0x100
        j = 0
        print("start %08x" % addr)
        print("end   %08x" % addr_end)
        while addr < addr_end:
            print("# %08x: %d: %d : %d" % (addr, i, j, 256))
            chunk = data[j*256:j*256+256]
            self._write_block(i, chunk, chunk)
            i = i+1
            j = j+1
            addr = addr + 256
        residue = len(data) % 256
        if residue != 0:
            print("# %08x: %d : %d" % (addr, i, residue))
            chunk = data[j*256:]
            # Only the remaining bytes are sent, but the device's CRC is
            # compared against the block padded to 256 bytes with 0xff.
            self._write_block(i, chunk, chunk.ljust(256, b'\xff'))

    def protect(self):
        """Send request 4 (presumably enabling protection; confirm).

        Prints "protection failure" if request 2 then reports 0.
        """
        self.__devhandle.controlMsg(requestType = 0x40, request = 4,
                                    value = 0, index = 0, buffer = None,
                                    timeout = 10000)
        time.sleep(0.100)
        if self._read_result() == 0:
            print("protection failure")

    def finish(self):
        """Send request 5 to tell the device the upgrade is finished."""
        self.__devhandle.controlMsg(requestType = 0x40, request = 5,
                                    value = 0, index = 0, buffer = None,
                                    timeout = 10000)

    def reset_device(self):
        """Reset the USB device; errors are deliberately ignored."""
        try:
            self.__devhandle.reset()
        except Exception:
            pass
144
class neug(object):
    """Handle on a NeuG device's communication interface.

    main() uses it to stop the device, query memory information, download
    a program and start executing it, all via vendor control transfers.
    """
    def __init__(self, device, configuration, interface):
        """
        __init__(device, configuration, interface) -> None
        Open the device and claim the given interface.
        device: usb.Device object.
        configuration: usb.Configuration object (as yielded by com_devices()).
        interface: usb.Interface object representing the interface and
                   alternate setting.
        Raises ValueError if the interface class/subclass are not
        COM_CLASS/COM_SUBCLASS.
        """
        if interface.interfaceClass != COM_CLASS:
            raise ValueError("Wrong interface class")
        if interface.interfaceSubClass != COM_SUBCLASS:
            raise ValueError("Wrong interface sub class")
        self.__devhandle = device.open()
        try:
            self.__devhandle.setConfiguration(configuration)
        except Exception:
            # Best effort: a failure here is deliberately ignored.
            pass

        try:
            self.__devhandle.detachKernelDriver(interface)
        except usb.USBError:
            # Nothing could be detached (e.g. no kernel driver is bound to
            # the interface).  Carry on; claimInterface() below will report
            # any real conflict.
            pass
        self.__devhandle.claimInterface(interface)
        self.__devhandle.setAltInterface(interface)

        self.__intf = interface.interfaceNumber
        self.__alt = interface.alternateSetting
        self.__conf = configuration

        # NOTE(review): not read by any method below; each control transfer
        # passes its own (10 ms) timeout.
        self.__timeout = 10000

    def reset_device(self):
        """Reset the USB device; errors are deliberately ignored."""
        try:
            self.__devhandle.reset()
        except Exception:
            pass

    def stop_neug(self):
        """Send vendor request 255, then release the interface and set
        configuration 0."""
        self.__devhandle.controlMsg(requestType = 0x40, request = 255,
                                    value = 0, index = 0, buffer = None,
                                    timeout = 10)
        self.__devhandle.releaseInterface()
        self.__devhandle.setConfiguration(0)

    def mem_info(self):
        """Return (start, end) decoded as two little-endian 32-bit words
        from vendor request 0."""
        mem = self.__devhandle.controlMsg(requestType = 0xc0, request = 0,
                                          value = 0, index = 0, buffer = 8,
                                          timeout = 10)
        start = ((mem[3]*256 + mem[2])*256 + mem[1])*256 + mem[0]
        end = ((mem[7]*256 + mem[6])*256 + mem[5])*256 + mem[4]
        return (start, end)

    def download(self, start, data):
        """Send data in 256-byte blocks with vendor request 1.

        Each request carries its block index, counted in 256-byte units
        from 0x20000000.  A trailing partial block is sent as-is.
        """
        addr = start
        addr_end = (start + len(data)) & 0xffffff00
        # Floor division keeps the block index an int on Python 2 and 3.
        i = (addr - 0x20000000) // 0x100
        j = 0
        print("start %08x" % addr)
        print("end   %08x" % addr_end)
        while addr < addr_end:
            print("# %08x: %d : %d" % (addr, i, 256))
            self.__devhandle.controlMsg(requestType = 0x40, request = 1,
                                        value = i, index = 0,
                                        buffer = data[j*256:j*256+256],
                                        timeout = 10)
            i = i+1
            j = j+1
            addr = addr + 256
        residue = len(data) % 256
        if residue != 0:
            print("# %08x: %d : %d" % (addr, i, residue))
            self.__devhandle.controlMsg(requestType = 0x40, request = 1,
                                        value = i, index = 0,
                                        buffer = data[j*256:],
                                        timeout = 10)

    def execute(self, last_addr):
        """Send vendor request 2 with last_addr split into a 256-byte block
        index (value) and an offset within that block (index), both
        relative to 0x20000000."""
        i = (last_addr - 0x20000000) // 0x100
        o = (last_addr - 0x20000000) % 0x100
        self.__devhandle.controlMsg(requestType = 0x40, request = 2,
                                    value = i, index = o, buffer = None,
                                    timeout = 10)
226
def compare(data_original, data_in_device):
    """Check that data_in_device matches data_original byte for byte.

    data_original: reference byte string (str on Python 2, bytes on 3).
    data_in_device: indexable sequence of integer byte values.
    Extra trailing bytes in data_in_device are ignored.
    Raises ValueError("verify failed at <offset>") at the first mismatch,
    including when data_in_device is shorter than data_original.
    """
    # bytearray yields integer byte values on both Python 2 and 3, so no
    # per-character ord() is needed.
    available = len(data_in_device)
    for offset, expected in enumerate(bytearray(data_original)):
        if offset >= available or expected != data_in_device[offset]:
            raise ValueError("verify failed at %08x" % offset)
233
def com_devices():
    """Generate (device, configuration, alternate setting) tuples.

    Every USB interface alternate setting on every bus is visited; those
    whose class, subclass and protocol equal COM_CLASS, COM_SUBCLASS and
    COM_PROTOCOL_0 are yielded.
    """
    for usb_bus in usb.busses():
        for device in usb_bus.devices:
            for cfg in device.configurations:
                for interface in cfg.interfaces:
                    for setting in interface:
                        is_com = (setting.interfaceClass == COM_CLASS
                                  and setting.interfaceSubClass == COM_SUBCLASS
                                  and setting.interfaceProtocol == COM_PROTOCOL_0)
                        if is_com:
                            yield device, cfg, setting
246
# Vendor / product IDs that gnuk_devices() matches against.
USB_VENDOR_FSIJ, USB_PRODUCT_GNUK = 0x234b, 0x0000
249
def gnuk_devices():
    """Generate every attached USB device whose vendor ID is USB_VENDOR_FSIJ
    and whose product ID is USB_PRODUCT_GNUK."""
    for usb_bus in usb.busses():
        for device in usb_bus.devices:
            if device.idVendor == USB_VENDOR_FSIJ and \
                    device.idProduct == USB_PRODUCT_GNUK:
                yield device
260
def to_string(t):
    """Return a string with one character chr(c) for each integer c in t."""
    # str.join builds the result in one pass instead of the quadratic
    # repeated concatenation.
    return "".join(chr(c) for c in t)
266
def UNSIGNED(n):
    """Return the low 32 bits of n as a non-negative integer."""
    low_32_bits = (1 << 32) - 1
    return n & low_32_bits
269
def crc32(bytestr):
    """Return the CRC-32 of bytestr as an unsigned 32-bit integer.

    The result of binascii.crc32 is masked to 32 bits so the value is the
    same on every Python version.
    """
    return binascii.crc32(bytestr) & 0xffffffff
273
def main(data_regnual, data_upgrade):
    """Upgrade a NeuG device in two stages.

    Stage 1 prepares and starts the flash upgrade program:
      - pad data_regnual to a multiple of 4 bytes and append its CRC-32
        (little-endian);
      - download it to the first NeuG device found by com_devices();
      - start it, then reset the device.

    Stage 2 writes the new firmware:
      - find the device again via gnuk_devices();
      - download data_upgrade through the regnual protocol;
      - protect, finish and reset.

    data_regnual: image of the flash upgrade program (byte string).
    data_upgrade: firmware image to write (byte string).
    Returns 0 on completion.
    Raises ValueError if no usable device is found in either stage.
    """
    # Pad to a 4-byte boundary before appending the CRC word.
    l = len(data_regnual)
    if (l & 0x03) != 0:
        data_regnual = data_regnual.ljust(l + 4 - (l & 0x03), b'\x00')
    crc32code = crc32(data_regnual)
    # %08x: a CRC-32 is a full 32-bit value; keep its leading zeros.
    print("CRC32: %08x\n" % crc32code)
    data_regnual += pack('<I', crc32code)

    # Stage 1: load the flash upgrade program through the NeuG interface.
    com = None
    for (dev, config, intf) in com_devices():
        try:
            com = neug(dev, config, intf)
        except Exception:
            # This candidate could not be opened; try the next one.
            continue
        print("Device:  %s" % dev.filename)
        print("Configuration:  %s" % config.value)
        print("Interface:  %s" % intf.interfaceNumber)
        break
    if com is None:
        raise ValueError("No NeuG Device Present")
    com.stop_neug()
    time.sleep(0.500)
    mem_info = com.mem_info()
    print("%08x:%08x" % mem_info)
    print("Downloading flash upgrade program...")
    com.download(mem_info[0], data_regnual)
    print("Run flash upgrade program...")
    # execute() is given the address of the appended CRC word.
    com.execute(mem_info[0] + len(data_regnual) - 4)
    time.sleep(3)
    com.reset_device()
    del com
    print("Wait 3 seconds...")
    time.sleep(3)

    # Stage 2: send the upgrade image to the device now running regnual.
    reg = None
    for dev in gnuk_devices():
        try:
            reg = regnual(dev)
        except Exception:
            continue
        print("Device:  %s" % dev.filename)
        break
    if reg is None:
        # Previously this fell through to reg.mem_info() and died with an
        # AttributeError on None.
        raise ValueError("No reGNUal Device Present")
    mem_info = reg.mem_info()
    print("%08x:%08x" % mem_info)
    print("Downloading the program")
    reg.download(mem_info[0], data_upgrade)
    reg.protect()
    reg.finish()
    reg.reset_device()
    return 0
326
327
if __name__ == '__main__':
    if len(sys.argv) < 3:
        sys.stderr.write("Usage: %s REGNUAL_IMAGE UPGRADE_IMAGE\n"
                         % sys.argv[0])
        sys.exit(2)
    filename_regnual = sys.argv[1]
    filename_upgrade = sys.argv[2]
    # Firmware images are binary. Read them in 'rb' mode so that no
    # newline translation can corrupt them on platforms that apply it.
    with open(filename_regnual, 'rb') as f:
        data_regnual = f.read()
    print("%s: %d" % (filename_regnual, len(data_regnual)))
    with open(filename_upgrade, 'rb') as f:
        data_upgrade = f.read()
    print("%s: %d" % (filename_upgrade, len(data_upgrade)))
    # NOTE(review): the first 4096 bytes of the upgrade image are not sent;
    # presumably that region is not rewritten on the device. Confirm.
    sys.exit(main(data_regnual, data_upgrade[4096:]))