2014年6月3日火曜日

scikit-learnを使って数字認識(2) k-NNを使った学習

scikit-learnを使えば簡単な機械学習ならすぐにできる。
今回は特徴ベクトルとして画素値をもちい、それをk-NearestNeighborsにぶちこむ。
k-meansもためしたけど、こちらはあまりうまく行かなかった。
Pythonだと学習した識別器(オブジェクト)をPickleで保存して簡単に取り出せる。

10種類くらいのフォントを学習させて11種類目でためしたが、正答率は100%だった。今後もっといろんな状況下(画素が少ないサンプルとか)で試していってどこまでいけるか確かめたい。

以下メモ

学習部分
import numpy as np
import cv2
from sklearn.neighbors import KNeighborsClassifier
from sampling import convert_to_binary
"""
1. Read sample image and convert to 1d feature vector
2. pass the feature vectors to kmeans clustering maching
3. pickle the result.
"""

def convert_to_feature_vector(src):
    """
    convert source image to a 1d feature vector
    """
    N =10
    im = cv2.GaussianBlur(src,(5,5),0)
    im = cv2.resize(src, (N, N))
    im = im.reshape(N**2)
    #im = np.array(im>124, dtype=np.int8) #convert to 0 and 1
    return im


if __name__ == "__main__":
    #read jpgs, and resize them as a vector

    dirname = [str(i) for i in range(10)]
    dirname += ["dot"]
    dirname += ["bar"]

    X = [] #sample data
    Y = [] #label
    for dn in dirname:

        #input label(0-11)
        if dn == "bar":
            label = 11
        elif dn == "dot":
            label = 10
        else:
            label = int(dn)


        fnames = os.listdir(dn)
        fnames.remove(".DS_Store")

        for fn in fnames:
            #print os.path.join(dn, fn)
            im = cv2.imread(os.path.join(dn, fn))
            im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
            im = convert_to_feature_vector(im)
            #print im.reshape(10,10)
            X.append(im)
            Y.append(label)

        #convert to np.array
    X = np.array(X)

    #
    knn = KNeighborsClassifier(n_neighbors=5)
    knn.fit(X, Y)

    #save the classifier as pickle
    import pickle
    with open("knn_trained.dump", "w") as f:
        pickle.dump(knn, f)



予測部分
#coding: utf-8

from sklearn import cluster
import pickle
import cv2
from training_knn import convert_to_feature_vector


if __name__ == "__main__":
    with open("knn_trained.dump") as f:
        knn = pickle.load(f)
        print knn

    import os
    dirname = "test_data"
    fnames = os.listdir(dirname)
    try:
        fnames.remove(".DS_Store")
    except ValueError as e:
        print e


    for fn in fnames:
        im = cv2.imread(os.path.join(dirname, fn))
        im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
        x = convert_to_feature_vector(im)
        y = knn.predict([x])
        print fn, y



0 件のコメント:

コメントを投稿