scikit-learnを使えば簡単な機械学習ならすぐにできる。
今回は特徴ベクトルとして画素値をもちい、それをk-NearestNeighborsにぶちこむ。
k-meansもためしたけど、こちらはあまりうまく行かなかった。
Pythonだと学習した識別器(オブジェクト)をPickleで保存して簡単に取り出せる。
10種類くらいのフォントを学習させて11種類目でためしたが、正答率は100%だった。今後もっといろんな状況下(画素が少ないサンプルとか)で試していってどこまでいけるか確かめたい。
以下メモ
学習部分
予測部分
今回は特徴ベクトルとして画素値をもちい、それをk-NearestNeighborsにぶちこむ。
k-meansもためしたけど、こちらはあまりうまく行かなかった。
Pythonだと学習した識別器(オブジェクト)をPickleで保存して簡単に取り出せる。
10種類くらいのフォントを学習させて11種類目でためしたが、正答率は100%だった。今後もっといろんな状況下(画素が少ないサンプルとか)で試していってどこまでいけるか確かめたい。
以下メモ
学習部分
import numpy as np import cv2 from sklearn.neighbors import KNeighborsClassifier from sampling import convert_to_binary """ 1. Read sample image and convert to 1d feature vector 2. pass the feature vectors to kmeans clustering maching 3. pickle the result. """ def convert_to_feature_vector(src): """ convert source image to a 1d feature vector """ N =10 im = cv2.GaussianBlur(src,(5,5),0) im = cv2.resize(src, (N, N)) im = im.reshape(N**2) #im = np.array(im>124, dtype=np.int8) #convert to 0 and 1 return im if __name__ == "__main__": #read jpgs, and resize them as a vector dirname = [str(i) for i in range(10)] dirname += ["dot"] dirname += ["bar"] X = [] #sample data Y = [] #label for dn in dirname: #input label(0-11) if dn == "bar": label = 11 elif dn == "dot": label = 10 else: label = int(dn) fnames = os.listdir(dn) fnames.remove(".DS_Store") for fn in fnames: #print os.path.join(dn, fn) im = cv2.imread(os.path.join(dn, fn)) im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) im = convert_to_feature_vector(im) #print im.reshape(10,10) X.append(im) Y.append(label) #convert to np.array X = np.array(X) # knn = KNeighborsClassifier(n_neighbors=5) knn.fit(X, Y) #save the classifier as pickle import pickle with open("knn_trained.dump", "w") as f: pickle.dump(knn, f)
予測部分
#coding: utf-8 from sklearn import cluster import pickle import cv2 from training_knn import convert_to_feature_vector if __name__ == "__main__": with open("knn_trained.dump") as f: knn = pickle.load(f) print knn import os dirname = "test_data" fnames = os.listdir(dirname) try: fnames.remove(".DS_Store") except ValueError as e: print e for fn in fnames: im = cv2.imread(os.path.join(dirname, fn)) im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) x = convert_to_feature_vector(im) y = knn.predict([x]) print fn, y
0 件のコメント:
コメントを投稿