티스토리 뷰

Softmax classification



Tensorflow로 구현

-4개의 변수에 의해 3가지 등급으로 분류하는 학습


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import tensorflow as tf
 
# x_data : 4개의 변수로 구성
# y_data : one-hot 방식으로 3가지 label
# [1, 0, 0] = 첫 번째
# [0, 1, 0] = 두 번째
# [0, 0, 1] = 세 번째
x_data = [[1211], [2132], [3134], [4155],
          [1755], [1256], [1666], [1777]]
y_data = [[001], [001], [001], [010],
          [010], [010], [100], [100]]
 
# label의 개수 = y_data 분류 개수
nb_classes = 3
 
# placeholder
= tf.placeholder(tf.float32, shape=[None, 4])
= tf.placeholder(tf.float32, shape=[None, 3])
 
# x_data의 변수가 4개이므로, W도 4개 / y_label의 분류가 3개이므로, binary classification 3번
= tf.Variable(tf.random_normal([4, nb_classes]), name="weight")
= tf.Variable(tf.random_normal([nb_classes]), name="bias")
 
# Hypothesis : softmax function 사용
# softmax = exp(logits) / reduce_sum(exp(logits))
hypothesis = tf.nn.softmax(tf.matmul(X, W) + b)
 
# cost/loss function : cross entropy
# axis = 1는 matmul이 아닌 같은 element의 곱을 의미
cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(hypothesis), axis=1))
 
# Minimize : Gradient Descent 사용
train = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(cost)
 
# 세션 시작
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
 
    for step in range(2001):
        sess.run(train, feed_dict={X: x_data, Y: y_data})
        if step % 200 == 0:
            print(step, sess.run(cost, feed_dict={X: x_data, Y: y_data}))
            # 2000 0.16968
 
    # Test
    # tf.arg_max() = one-hot encoding으로 가장 큰 값의 index를 return
    a = sess.run(hypothesis, feed_dict={X: [[11179]]})
    print(a, "\n결과:", sess.run(tf.arg_max(a, 1)))
    # [[3.48425168e-03   9.96506214e-01   9.58935289e-06]]
    # 결과: [1]
 
    all = sess.run(hypothesis, feed_dict={X: [[11179], [1343], [1101]]})
    print(all, "\n결과:", sess.run(tf.arg_max(all, 1)))
    # [[2.99357832e-03   9.96996760e-01   9.61958904e-06]
    #  [8.89271736e-01   9.94489938e-02   1.12793017e-02]
    # [9.41215550e-09
    # 3.29720846e-04
    # 9.99670267e-01]]
    # 결과: [1 0 2]
 
cs


댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2025/02   »
1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28
글 보관함