텐서플로우로 머신러닝 하기(bmi 구하기)
2018. 5. 8. 01:26ㆍPython-이론/python-인공지능
텐서플로우로 머신러닝 하기
데이터 파일
위에 첨부된 데이터 파일을 사용하면 된다.
import pandas as pd import tensorflow as tf import numpy as np #hot cool encoding label_zero = {"thin":[1,0,0],"normal":[0,1,0],"fat":[0,0,1]} csv = pd.read_csv("bmi.csv") #정규화 csv["label_pat"]= csv["label"].apply(lambda x :np.array(label_zero[x])) csv["weight"] = csv["weight"].map(lambda x: x/100) csv["height"] = csv["height"].map(lambda x:x/200) x = tf.placeholder(tf.float32,[None,2],name="x") #정답 값 y_ = tf.placeholder(tf.float32,[None,3],name="y") #테스트 데이터 분류 test_csv = csv[15000:20000] test_x = test_csv[["height","weight"]] test_y = list(test_csv["label_pat"]) W = tf.Variable(tf.zeros([2,3]),name = "W") b = tf.Variable(tf.zeros([3]),name="b") sess = tf.Session() sess.run(tf.global_variables_initializer()) #예측 값 #소프트 맥스 회귀 로지스틱 회귀 값을 바탕으로 예측 분류함 y = tf.nn.softmax(tf.matmul(x,W)+b) # 경사하강법을 통해 적절한 매개변수 찾는다. #오차 최소화 cross_entry = -tf.reduce_sum(y_*tf.log(y)) #학습률 0.03 optimizer = tf.train.GradientDescentOptimizer(0.03) train = optimizer.minimize(cross_entry) #arg max 차원중 제일 수가 큰 인덱스 값을 반환한다. predict = tf.equal(tf.argmax(y,1) , tf.argmax(y_,1)) # 차원수를 줄여가며 평균을 구한다. accuracy = tf.reduce_mean(tf.cast(predict,tf.float32)) for step in range(3500): i = (step*100) % 15000 rows = csv[i + 1:i+100] feed_x = rows[["height","weight"]] feed_y = list(rows["label_pat"]) fd = {x:feed_x, y_ : feed_y} sess.run(train,feed_dict=fd) if step % 300 == 0: test_fd = {x:test_x, y_ : test_y} acc = sess.run(accuracy,feed_dict = test_fd) print("step = ",step,"acc= ",acc) acc = sess.run(accuracy, feed_dict={x:test_x,y_:test_y}) print(acc) print('키를 입력해주세요') height = input() print('몸무게를 입력해주세요!') weight = input() inputData = {"height":[int(height)/200],"weight":[int(weight)/100]} inputData = pd.DataFrame(inputData) answer = [[0,0,1]] acc = sess.run(accuracy,feed_dict={x:inputData,y_:answer}) print(acc)
'Python-이론 > python-인공지능' 카테고리의 다른 글
pandas와 numpy 다루기 (0) | 2018.05.11 |
---|---|
Tensorboard 사용해보기 (0) | 2018.05.08 |
데이터 검증하기-cross-validation, grid-search (0) | 2018.04.25 |
randomForest 사용해보기 (0) | 2018.04.25 |
machineLearning의 svm이 무엇이고 직접 그래프로 구분짓기 (0) | 2018.04.22 |