我们的OCR系统主要由5部分组成,分别写在5个文件之中。它们分别是:
- 客户端(ocr.js)
- 服务器(server.py)
- 简单的用户界面(ocr.html)
- 基于反向传播训练的ANN(ocr.py)
- ANN的实现脚本(neuralnetworkdesign.py)
虽然界面服务器用户界面不是我们的重点,但由于笔者水平有限,必须要从这些地方撸起。首先看最简单的用户界面,它将是我们使用的入口
<html> <head> <script src="ocr.js"></script> <link rel="stylesheet" type="text/css" href="ocr.css"> </head> <body > <div id="main-container" style="text-align: center;"> <h1>OCR Demo</h1> <canvas id="canvas" width="200" height="200"></canvas> <form name="input"> <p>Digit: <input id="digit" type="text"> </p> <input type="button" value="Train" onclick="ocrDemo.train()"> <input type="button" value="Test" onclick="ocrDemo.test()"> <input type="button" value="Reset" onclick="ocrDemo.resetCanvas();"/> </form> </div> </body> </html>
canvas是一个必须由脚本绘制的图形容器。除此之外我们定义了三个按钮调用js的处理函数
以下是完整的js脚本ocr.js
/** * This module creates a 200x200 pixel canvas for a user to draw * digits. The digits can either be used to train the neural network * or to test the network‘s current prediction for that digit. * * To simplify computation, the 200x200px canvas is translated as a 20x20px * canvas to be processed as an input array of 1s (white) and 0s (black) on * on the server side. Each new translated pixel‘s size is 10x10px * * When training the network, traffic to the server can be reduced by batching * requests to train based on BATCH_SIZE. */ var ocrDemo = { CANVAS_WIDTH: 200, TRANSLATED_WIDTH: 20, PIXEL_WIDTH: 10, // TRANSLATED_WIDTH = CANVAS_WIDTH / PIXEL_WIDTH BATCH_SIZE: 1, // Server Variables PORT: "8000", HOST: "http://localhost", // Colors BLACK: "#000000", BLUE: "#0000ff", trainArray: [], trainingRequestCount: 0, onLoadFunction: function() { this.resetCanvas(); }, resetCanvas: function() { var canvas = document.getElementById(‘canvas‘); var ctx = canvas.getContext(‘2d‘); this.data = []; ctx.fillStyle = this.BLACK; ctx.fillRect(0, 0, this.CANVAS_WIDTH, this.CANVAS_WIDTH); var matrixSize = 400; while (matrixSize--) this.data.push(0); this.drawGrid(ctx); canvas.onmousemove = function(e) { this.onMouseMove(e, ctx, canvas) }.bind(this); canvas.onmousedown = function(e) { this.onMouseDown(e, ctx, canvas) }.bind(this); canvas.onmouseup = function(e) { this.onMouseUp(e, ctx) }.bind(this); }, drawGrid: function(ctx) { for (var x = this.PIXEL_WIDTH, y = this.PIXEL_WIDTH; x < this.CANVAS_WIDTH; x += this.PIXEL_WIDTH, y += this.PIXEL_WIDTH) { ctx.strokeStyle = this.BLUE; ctx.beginPath(); ctx.moveTo(x, 0); ctx.lineTo(x, this.CANVAS_WIDTH); ctx.stroke(); ctx.beginPath(); ctx.moveTo(0, y); ctx.lineTo(this.CANVAS_WIDTH, y); ctx.stroke(); } }, onMouseMove: function(e, ctx, canvas) { if (!canvas.isDrawing) { return; } this.fillSquare(ctx, e.clientX - canvas.offsetLeft, e.clientY - canvas.offsetTop); }, onMouseDown: function(e, ctx, canvas) { canvas.isDrawing = true; this.fillSquare(ctx, e.clientX - canvas.offsetLeft, e.clientY - canvas.offsetTop); }, onMouseUp: function(e) { canvas.isDrawing = false; }, fillSquare: function(ctx, x, y) { var xPixel = Math.floor(x / this.PIXEL_WIDTH); var yPixel = Math.floor(y / this.PIXEL_WIDTH); this.data[((xPixel - 1) * this.TRANSLATED_WIDTH + yPixel) - 1] = 1; ctx.fillStyle = ‘#ffffff‘; ctx.fillRect(xPixel * this.PIXEL_WIDTH, yPixel * this.PIXEL_WIDTH, this.PIXEL_WIDTH, this.PIXEL_WIDTH); }, train: function() { var digitVal = document.getElementById("digit").value; if (!digitVal || this.data.indexOf(1) < 0) { alert("Please type and draw a digit value in order to train the network"); return; } this.trainArray.push({"y0": this.data, "label": parseInt(digitVal)}); this.trainingRequestCount++; // Time to send a training batch to the server. if (this.trainingRequestCount == this.BATCH_SIZE) { alert("Sending training data to server..."); var json = { trainArray: this.trainArray, train: true }; this.sendData(json); this.trainingRequestCount = 0; this.trainArray = []; } }, test: function() { if (this.data.indexOf(1) < 0) { alert("Please draw a digit in order to test the network"); return; } var json = { image: this.data, predict: true }; this.sendData(json); }, receiveResponse: function(xmlHttp) { if (xmlHttp.status != 200) { alert("Server returned status " + xmlHttp.status); return; } var responseJSON = JSON.parse(xmlHttp.responseText); if (xmlHttp.responseText && responseJSON.type == "test") { alert("The neural network predicts you wrote a \‘" + responseJSON.result + ‘\‘‘); } }, onError: function(e) { alert("Error occurred while connecting to server: " + e.target.statusText); }, sendData: function(json) { var xmlHttp = new XMLHttpRequest(); xmlHttp.open(‘POST‘, this.HOST + ":" + this.PORT, false); xmlHttp.onload = function() { this.receiveResponse(xmlHttp); }.bind(this); xmlHttp.onerror = function() { this.onError(xmlHttp) }.bind(this); var msg = JSON.stringify(json); xmlHttp.setRequestHeader(‘Content-length‘, msg.length); xmlHttp.setRequestHeader("Connection", "close"); xmlHttp.send(msg); } }
虽然javascript本来是不支持类的,但可以用“极简主义法”的方式定义类,参看:http://www.ruanyifeng.com/blog/2012/07/three_ways_to_define_a_javascript_class.html
如我们的var ocrDemo就可以看作一个以极简主义法定义的类。而this指针是js语言的一个关键字,它在函数调用的时候自动生成,并且它总是指向调用函数的那个对象
如此,结合canvas的一些方法,画画过程就不难看懂了:我们把10*10的一个真实像素化为一个我们的像素,为网格填充颜色后,监听鼠标动作,每次点击和move调用
fillSquare 函数填充一个方块
解决了画画问题,下一步要将数据传输到服务器,让它进行相关的学习,其中
sendData: function(json) { var xmlHttp = new XMLHttpRequest(); xmlHttp.open(‘POST‘, this.HOST + ":" + this.PORT, false); console.log(this.HOST+":"+this.PORT) xmlHttp.onload = function() { this.receiveResponse(xmlHttp); }.bind(this); xmlHttp.onerror = function() { this.onError(xmlHttp) }.bind(this); var msg = JSON.stringify(json); xmlHttp.setRequestHeader(‘Content-length‘, msg.length); xmlHttp.setRequestHeader("Connection", "close"); console.log("fuck") xmlHttp.send(msg); }
xmlhttprequest对象可以用于幕后和服务器交换数据
该对象有如下方法:
abort() | 取消当前的请求。 |
getAllResponseHeaders() | 返回头信息。 |
getResponseHeader() | 返回指定的头信息。 |
open(method,url,async,uname,pswd) | 规定请求的类型,URL,请求是否应该进行异步处理,以及请求的其他可选属性。
method:请求的类型:GET 或 POST |
send(string) | 发送请求到服务器。
string:仅用于 POST 请求 |
setRequestHeader() | 把标签/值对添加到要发送的头文件。 |
值得注意的是open方法中的同步(false)异步(true)请求,如果设置为同步,那么在未收到返回数据之前,浏览器页面是不能进行其他操作的。如果设置为同步,则可以进行其他操作,但服务器返回的数据可能收不到了。笔者没有在windows下装科学计算环境,服务器挂在了阿里云上。这又带来了一些问题:
1.由于需要实时响应,向远程服务器发送的速度非常慢
2.跨域访问通常情况下被拒绝
3.服务器端若不关闭防火墙将不能收到post请求
什么是跨域访问呢?从一个域名的网页访问另一网页的资源时,只要协议域名端口有任何一个不同,就被称为跨域访问。跨域访问由于一些安全问题,通常情况是拒绝的
我们需要在服务器端处理post请求时添加报文头部 s.send_header("Access-Control-Allow-Origin", "*")
其中*代表任意,即我们允许任意域访问。这其中又有一个细节,当跨域访问无权限时,服务端还是能够收到请求报文的,但不会对它进行处理
服务器部分源码如下:
import BaseHTTPServer import json from ocr import OCRNeuralNetwork import numpy as np HOST_NAME = ‘localhost‘ PORT_NUMBER = 8000 HIDDEN_NODE_COUNT = 15 # Load data samples and labels into matrix data_matrix = np.loadtxt(open(‘data.csv‘, ‘rb‘), delimiter = ‘,‘) data_labels = np.loadtxt(open(‘dataLabels.csv‘, ‘rb‘)) # Convert from numpy ndarrays to python lists data_matrix = data_matrix.tolist() data_labels = data_labels.tolist() # If a neural network file does not exist, train it using all 5000 existing data samples. # Based on data collected from neural_network_design.py, 15 is the optimal number # for hidden nodes nn = OCRNeuralNetwork(HIDDEN_NODE_COUNT, data_matrix, data_labels, list(range(5000))); class JSONHandler(BaseHTTPServer.BaseHTTPRequestHandler): def do_POST(s): print "fuck" response_code = 200 response = "" var_len = int(s.headers.get(‘Content-Length‘)) content = s.rfile.read(var_len); payload = json.loads(content); if payload.get(‘train‘): nn.train(payload[‘trainArray‘]) nn.save() elif payload.get(‘predict‘): try: response = {"type":"test", "result":nn.predict(str(payload[‘image‘]))} except: response_code = 500 else: response_code = 400 s.send_response(response_code) s.send_header("Content-type", "application/json") s.send_header("Access-Control-Allow-Origin", "*") s.end_headers() if response: s.wfile.write(json.dumps(response)) return Page = ‘‘‘ <html> <body> <p>Hello, web!</p> </body> </html> ‘‘‘ # Handle a GET request. def do_GET(self): self.send_response(200) self.send_header("Content-type", "text/html") self.send_header("Content-Length", str(len(self.Page))) self.end_headers() self.wfile.write(self.Page) if __name__ == ‘__main__‘: server_class = BaseHTTPServer.HTTPServer; httpd = server_class((HOST_NAME, PORT_NUMBER), JSONHandler) try: httpd.serve_forever() except KeyboardInterrupt: pass else: print "Unexpected server exception occurred." finally: httpd.server_close()
逻辑也不难理解。先预处理数据,然后使用basehttpserver开启服务器,重写post请求
aseHTTPRequestHandler 实例有下列方法:
handle()
调用 handle_one_request()一次 (或,如果能够持续连接,多次) 处理进来的 HTTP 请求。你从不需要重载它;替代,实现对应的 do_*() 方法。
handle_one_request()
这个方法将解析和分派请求到对应的 do_*() 方法。你从不需要重载它。
send_error(code[, message])
发送并记录一个完整的错误回复到客户端。数字的 code 指定 HTTP 错误代码,以 message 作为可选的,更多指定的文本。全套的头被发送,后面紧跟使用 the error_message_format 类变量组成的文本。
send_response(code[, message])
发送一个响应头并记录已接收的请求。HTTP 响应行被发送,后面紧跟 Server 和 Date 头。这两个头的值分别地从 version_string() 和 date_time_string() 方法中获得。
send_header(keyword, value)
编写一个指定的 HTTP 头到输出流。 keyword 应该指定头关键字,value 指定它的值。
end_headers()
发送一个空白行,表示响应中的 HTTP 头结束。
log_request([code[, size]])
记录一个已接收的 (成功的) 请求。code 指定关联响应的数字的 HTTP 代码。如果响应的大小可用,那么它应该作为 size 参数被传递。
log_error(...)
当一个请求不能被完成时记录一个错误。缺省,它传递信息给 log_message(),因此它取相同的参数 (format 和 附加值)。
log_message( format,...)
记录一个随机信息给 sys.stderr。典型地重载创建自定义的错误日志结构。 format 参数是一个标准的printf风格的格式化字符串,附加参数给 log_message() 用于输出格式。客户端地址和当前的日期时间被作为记录的每个信息的前缀。
version_string()
返回服务器软件的版本字符串。这是一个 server_version 和 sys_version 类变量的联合。
date_time_string([timestamp])
返回通过 timestamp 给定的日期和时间(必须是由 time.time()返回的格式),格式化一个信息头。如果 timestamp 被省略,它使用当前的日期和时间。
结果像 ‘Sun, 06 Nov 1994 08:49:37 GMT‘。2.5 版本中的新特性: timestamp 参数。
log_date_time_string()
返回当前的日期和时间,格式化日志。
address_string()
返回客户端地址,格式化日志。一个名称的查找被执行在客户端的IP地址上。
如此基础部分就解读完成啦