{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import tensorflow as tf\n", "import pandas as pd\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "keys = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']\n", "data = pd.read_csv(\"iris.csv\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
casenoSepalLengthSepalWidthPetalLengthPetalWidthSpecies
015.13.51.40.2setosa
124.93.01.40.2setosa
234.73.21.30.2setosa
344.63.11.50.2setosa
455.03.61.40.2setosa
\n", "
" ], "text/plain": [ " caseno SepalLength SepalWidth PetalLength PetalWidth Species\n", "0 1 5.1 3.5 1.4 0.2 setosa\n", "1 2 4.9 3.0 1.4 0.2 setosa\n", "2 3 4.7 3.2 1.3 0.2 setosa\n", "3 4 4.6 3.1 1.5 0.2 setosa\n", "4 5 5.0 3.6 1.4 0.2 setosa" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.head()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "data = data.drop('caseno', axis=1)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SepalLengthSepalWidthPetalLengthPetalWidthSpecies
05.13.51.40.2setosa
14.93.01.40.2setosa
24.73.21.30.2setosa
34.63.11.50.2setosa
45.03.61.40.2setosa
\n", "
" ], "text/plain": [ " SepalLength SepalWidth PetalLength PetalWidth Species\n", "0 5.1 3.5 1.4 0.2 setosa\n", "1 4.9 3.0 1.4 0.2 setosa\n", "2 4.7 3.2 1.3 0.2 setosa\n", "3 4.6 3.1 1.5 0.2 setosa\n", "4 5.0 3.6 1.4 0.2 setosa" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.head()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "species = list(data['Species'].unique())\n", "data['class'] = data['Species'].map(lambda x: np.eye(len(species))[species.index(x)])" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SepalLengthSepalWidthPetalLengthPetalWidthSpeciesclass
1346.12.65.61.4virginica[0.0, 0.0, 1.0]
615.93.04.21.5versicolor[0.0, 1.0, 0.0]
285.23.41.40.2setosa[1.0, 0.0, 0.0]
465.13.81.60.2setosa[1.0, 0.0, 0.0]
165.43.91.30.4setosa[1.0, 0.0, 0.0]
\n", "
" ], "text/plain": [
 "     SepalLength  SepalWidth  PetalLength  PetalWidth     Species  \\\n",
 "134          6.1         2.6          5.6         1.4   virginica   \n",
 "61           5.9         3.0          4.2         1.5  versicolor   \n",
 "28           5.2         3.4          1.4         0.2      setosa   \n",
 "46           5.1         3.8          1.6         0.2      setosa   \n",
 "16           5.4         3.9          1.3         0.4      setosa   \n",
 "\n",
 "               class  \n",
 "134  [0.0, 0.0, 1.0]  \n",
 "61   [0.0, 1.0, 0.0]  \n",
 "28   [1.0, 0.0, 0.0]  \n",
 "46   [1.0, 0.0, 0.0]  \n",
 "16   [1.0, 0.0, 0.0]  " ] },
 "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ],
 "source": [ "data.sample(5)" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [],
 "source": [ "testset = data.sample(50)\n",
             "trainset = data.drop(testset.index)" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [],
 "source": [ "X = tf.placeholder(tf.float32, [None, 4])\n",
             "Y = tf.placeholder(tf.float32, [None, 3])" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [],
 "source": [ "W = tf.Variable(tf.zeros([4, 3]), name=\"Weight\")\n",
             "b = tf.Variable(tf.zeros([3]), name=\"Bias\")" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [],
 "source": [ "# Keep the unscaled logits separate from the softmax output:\n",
             "# softmax_cross_entropy_with_logits (next cell) must be fed logits.\n",
             "logits = tf.matmul(X, W) + b\n",
             "H = tf.nn.softmax(logits)" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [],
 "source": [ "# BUG FIX: the loss was previously computed from H (already softmaxed),\n",
             "# so softmax was applied twice and the gradients were squashed --\n",
             "# the recorded loss stalled near 0.59 instead of approaching 0.\n",
             "# Feed the raw logits to the loss op instead.\n",
             "cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits, Y, name=\"Cross_Entropy\"))\n",
             "optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [],
 "source": [ "accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(H, 1), tf.argmax(Y, 1)), tf.float32))" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [],
 "source": [ "init = tf.initialize_all_variables()\n",
             "\n",
             "sess = tf.Session()\n",
             "sess.run(init)" ] },
{ "cell_type": "code", 
"execution_count": 15, "metadata": { "collapsed": false },
 "outputs": [ { "name": "stdout", "output_type": "stream",
   "text": [ "0 0.31\n", "1.05821\n", "100 0.97\n", "0.700138\n", "200 0.87\n", "0.695254\n", "300 0.88\n", "0.676316\n", "400 0.99\n", "0.611448\n", "500 0.98\n", "0.605681\n", "600 0.98\n", "0.601644\n", "700 0.98\n", "0.598562\n", "800 0.98\n", "0.59612\n", "900 0.98\n", "0.594131\n", "1000 0.98\n", "0.592477\n" ] } ],
 "source": [ "# Materialize the one-hot labels once; reused by every feed_dict below.\n",
             "trainset_class = [y for y in trainset['class'].values]\n",
             "\n",
             "for step in xrange(1001):\n",
             "    sess.run(optimizer, feed_dict={X: trainset[keys].values,\n",
             "                                   Y: trainset_class})\n",
             "    if step % 100 == 0:\n",
             "        print step, sess.run(accuracy, feed_dict={X: trainset[keys].values,\n",
             "                                                  Y: trainset_class})\n",
             "        print sess.run(cross_entropy, feed_dict={X: trainset[keys].values,\n",
             "                                                 Y: trainset_class})" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false },
 "outputs": [ { "name": "stdout", "output_type": "stream",
   "text": [ "0.98\n", "The flower which has [[ 5.   3.5  1.6  0.6]] may be setosa\n", "Correct!\n" ] } ],
 "source": [ "# Check which data is error\n",
             "# trainset_result = sess.run(H, feed_dict={X: trainset[keys].values})\n",
             "\n",
             "# error_data = []\n",
             "\n",
             "# for x in xrange(trainset.shape[0]):\n",
             "#     if np.argmax(trainset_result[x], 0) != np.argmax(trainset_class[x], 0):\n",
             "#         error_data.append(trainset.values[x])\n",
             "    \n",
             "# print error_data\n",
             "\n",
             "# Check testdata's accuracy\n",
             "print sess.run(accuracy, feed_dict={X: testset[keys].values,\n",
             "                                    Y: [y for y in testset['class'].values]})\n",
             "\n",
             "# Test which species accords this data\n",
             "# BUG FIX: reuse the `species` list derived from the data in the\n",
             "# encoding cell -- re-hardcoding the label names here could silently\n",
             "# mismatch the one-hot encoding if the CSV row order ever changed.\n",
             "sample = data.sample(1)\n",
             "result = species[np.argmax(sess.run(H, feed_dict={X: sample[keys].values}))]\n",
             "print \"The flower which has\", sample[keys].values, \"may be\", result\n",
             "# BUG FIX: compare against the scalar label, not a length-1 ndarray\n",
             "# (str == ndarray only works here by accident of size-1 truthiness).\n",
             "if result == sample['Species'].values[0]:\n",
             "    print \"Correct!\"\n",
             "else:\n",
             "    print \"Incorrect!\"" ] } ],
 "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" },
 "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.12" } },
 "nbformat": 4, "nbformat_minor": 1 }