티스토리 뷰

오차역전파

계산 그래프

계산 그래프는 계산 과정을 그래프로 나타낸 것입니다. 자료구조에서 볼 수 있는 그래프 자료구조로, 노드엣지로 표현됩니다. 


계산 그래프의 이점

계산 그래프의 이점은 크게 2가지가 있습니다. 첫 번째는 국소적 계산입니다. 국소적 계산을 전파함으로써 최종 결과를 얻는다는 점입니다. 국소적이란 자신과 직접 관계된 작은 범위라는 뜻입니다. 즉, 독립적이라는 뜻이죠. 국소적 계산을 좀더 이해하기 위해서 계산 그래프를 그려 확인해 보겠습니다.

위 그림을 보면 다른 노드에서 4000원이라는 결과가 어떻게 나왔든지간에 단지 4000과 200을 더하기만 하는 것을 볼 수 있습니다. 각 노드가 자신과 관련한 계산 외에는 아무것도 신경 쓸 게 없습니다.


이처럼 계산 그래프는 국소적 계산에 집중합니다. 전체 계산이 아무리 복잡하더라도 각 단계에서 하는 일은 해당 노드의 국소적 계산입니다. 보기에는 간단하지만 우리 뇌도 이와 비슷하게 작동합니다.


두 번째 이점은 중간 계산 결과를 모두 보관할 수 있다는 점입니다. 이로인해 역전파를 미분으로 효율적으로 계산할 수 있습니다.


계산 그래프와 연쇄법칙

u=f(x)를 나타내는 계산 그래프는 아래 그림과 같습니다.

계산 그래프에서 계산을 왼쪽에서 오른쪽으로 진행하는 단계를 순전파라고 합니다. 반대로 계산을 오른쪽에서 왼쪽으로 진행하는 것을 역전파라고 합니다.


여기서 L/의 의미에 주목해야 합니다. 지금은 노드가 하나지만 실제로는 엄청나게 많습니다. 이 네트워크는 최종적으로는 정답과 비교한 뒤 

Loss를 구합니다.


지금 우리의 목적은 NN의 오차를 줄이는 데 있기 때문에, 각 파라미터별로 Loss에 대한 gradient를 구한 뒤 gradient들이 향한 쪽으로 파라미터들을 업데이트 합니다.


이제는 현재 입력값 x에 대한 Loss의 변화량 즉, L/y을 구할 겁니다. 이는 미분의 연쇄법칙에 의해 다음과 같이 계산됩니다.


이미 설명드렸듯 

L/y는 Loss로부터 흘러들어온 그래디언트입니다. y/x는 현재 입력값에 대한 현재 연산결과의 변화량, 즉 로컬 그래디언트(Local Gradient)입니다.

다시 말해 현재 입력값에 대한 Loss의 변화량은 Loss로부터 흘러들어온 그래디언트에 로컬 그래디언트를 곱해서 구한다는 이야기입니다. 이 그래디언트는 다시 앞쪽에 배치돼 있는 노드로 역전파됩니다.

덧셈 노드

덧셈 노드의 수식은 아래와 같습니다.

z=f(x,y)=x+y

덧셈 노드의 로컬 그래디언트는 아래와 같습니다.

zx=(x+y)x=1zy=(x+y)y=1

덧셈 노드의 계산그래프는 아래와 같습니다. 현재 입력값에 대한 Loss의 변화량은 로컬 그래디언트에 흘러들어온 그래디언트를 각각 곱해주면 됩니다. 덧셈 노드의 역전파는 흘러들어온 그래디언트를 그대로 흘려보내는 걸 확인할 수 있습니다.

덧셉 노드를 파이썬으로 구현해 보겠습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class AddLayer:
    def __init__(self):
        pass
 
    def forward(self, x, y):
        out = x + y
 
        return out
 
    def backward(self, dout):
        dx = dout * 1
        dy = dout * 1
 
        return dx, dy
 
cs

덧셈 노드에서는 초기화가 필요 없어 __init__()에서 아무 일도 일어나지 않습니다. 덧셈 노드에서 forward()에서는 입력받은 두 인수 x, y를 더해서 반환합니다. backward()에서는 상류에서 내려온 미분(dout)을 그대로 하류로 흘릴뿐입니다.

곱셈 노드

곱셈 노드의 수식은 아래와 같습니다.

z=f(x,y)=xy

곱셈 노드의 로컬 그래디언트는 아래와 같습니다.

zx=(xy)x=yzy=(xy)y=x

곱셈 노드의 계산그래프는 아래와 같습니다. 현재 입력값에 대한 Loss의 변화량은 로컬 그래디언트에 흘러들어온 그래디언트를 각각 곱해주면 됩니다. 곱셈 노드의 역전파는 순전파 때 입력 신호들을 서로 바꾼 값을 곱해서 하류로 흘려보내는 걸 확인할 수 있습니다.

곱셈 노드를 파이썬으로 구현해 보겠습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class MulLayer:
    def __init__(self):
        self.x = None
        self.y = None
 
    def forward(self, x, y):
        self.x = x
        self.y = y                
        out = x * y
 
        return out
 
    def backward(self, dout):
        dx = dout * self.y  # x와 y를 바꾼다.
        dy = dout * self.x
 
        return dx, dy
cs
__ init__()에서는 인스턴스 변수인 x와 y를 초기화합니다. 이 두변수는 순전파 시의 입력 값을 유지하기 위해서 사용합니다.


아래 코드는 덧셈 노드와 곱셈 노드의 계산 그래프를 파이썬으로 구현한 것입니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# coding: utf-8
from layer_naive import *
 
apple = 100
apple_num = 2
orange = 150
orange_num = 3
tax = 1.1
 
# layer
mul_apple_layer = MulLayer()
mul_orange_layer = MulLayer()
add_apple_orange_layer = AddLayer()
mul_tax_layer = MulLayer()
 
# forward
apple_price = mul_apple_layer.forward(apple, apple_num)  # (1)
orange_price = mul_orange_layer.forward(orange, orange_num)  # (2)
all_price = add_apple_orange_layer.forward(apple_price, orange_price)  # (3)
price = mul_tax_layer.forward(all_price, tax)  # (4)
 
# backward
dprice = 1
dall_price, dtax = mul_tax_layer.backward(dprice)  # (4)
dapple_price, dorange_price = add_apple_orange_layer.backward(dall_price)  # (3)
dorange, dorange_num = mul_orange_layer.backward(dorange_price)  # (2)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)  # (1)
 
print("price:"int(price))
print("dApple:", dapple)
print("dApple_num:"int(dapple_num))
print("dOrange:", dorange)
print("dOrange_num:"int(dorange_num))
print("dTax:", dtax)
 
cs

여기서 backward()가 받는 인수는 순전파의 출력에 대한 미분임에 주의하세요.


활성화 함수

ReLU 노드

활성화함수(activation function)로 사용되는 ReLU는 다음 식처럼 정의됩니다.

y=x(x>0)y=0(x0)

ReLU 노드의 로컬 그래디언트는 아래와 같습니다.

yx=1(x>0)yx=0(x0)

위 식을 보면 순전파 때의 입력인 x가 0보다 크면 역전파는 상류의 값을 그대로 하류로 흘립니다. 반면, 순전파 때 x가 0이하면 역전파 때는 하류로 신호를 흘려보내지 않습니다.(0을 보냄)

계산그래프는 아래와 같습니다.

ReLU계층을 구현해 보겠습니다. 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import numpy as np
from common.functions import *
from common.util import im2col, col2im
 
 
class Relu:
    def __init__(self):
        self.mask = None
 
    def forward(self, x):
        self.mask = (x <= 0)
        out = x.copy()
        out[self.mask] = 0
 
        return out
 
    def backward(self, dout):
        dout[self.mask] = 0
        dx = dout
 
        return dx
cs

ReLU클래스는 mask라는 인스턴스 변수를 가집니다. mask는 True/False로 구성된 넘파이 배열로, 순전파의 입력인 x의 원소 값이 0이하인 인덱스는 True, 그 외(0보다 큰 원소)는 False로 유지합니다. 

Sigmoid 노드

시그모이드(sigmoid) 함수는 아래와 같이 정의됩니다.

y=11+exp(x)

시그모이드 노드의 로컬 그래디언트는 다음과 같습니다.

yx=y(1y)

계산그래프는 아래와 같습니다.

시그모이드 계층을 파이썬으로 구현한 것입니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
class Sigmoid:
    def __init__(self):
        self.out = None
 
    def forward(self, x):
        out = sigmoid(x)
        self.out = out
        return out
 
    def backward(self, dout):
        dx = dout * (1.0 - self.out) * self.out
 
        return dx
cs

이 구현에서는 순전파의 출력을 인스턴스 변수out에 보관했다가, 역전파 계산 때 그 값을 사용합니다.



하이퍼볼릭탄젠트 노드

하이퍼볼릭탄젠트 노드 y=tanh(x)의 로컬 그래디언트는 다음과 같습니다.

yx=1y2

계산그래프는 아래와 같습니다.

Hadamard product 노드

Hadamard product란 요소별 곱셈을 뜻합니다. 기호로는  등을 씁니다. 예컨대 아래와 같습니다.

[132][453]=[4156]

두 벡터에 Hadamard product 연산을 적용했을 때, 그 로컬 그래디언트는 아래와 같습니다.

Hadamard product 노드 또한 다른 노드와 마찬가지로 위 로컬 그래디언트에 흘러들어온 그래디언트를 내적(inner product)해서 현시점의 그래디언트를 계산합니다. 그런데 흘러들어온 그래디언트 또한 벡터일 경우 Hadamard product 노드 로컬 그래디언트의 대각성분(위 그림에서 ht11,…,ht1n)과 요소별 곱셈을 하여도 같은 결과가 나옵니다.

벡터, 행렬로의 확장

지금까지 말씀드린 역전파는 기본적으로 스칼라를 대상으로 한 편미분과 역전파였습니다. 하지만 여기에 적용된 원칙들은 벡터, 행렬에도 적용할 수 있습니다. 해당 변수에 대한 그래디언트는 해당 변수의 차원 수와 일치해야 한다는 원칙을 기억하고 있으면 됩니다. 이와 관련 cs231n의 한 단락을 정리 용도로 캡처해 놨습니다.



-------------------------------------------
이 글은 밑바닥부터 시작하는 딥러닝과 
미국 스탠포드대학의 CS231n 강의를 듣고 정리한 글입니다.


'AI > 밑바닥부터 시작하는 딥러닝' 카테고리의 다른 글

매개변수 갱신  (0) 2018.09.27
오차역전파(2)  (0) 2018.09.18
신경망 학습  (0) 2018.09.13
신경망(2)  (0) 2018.09.09
신경망(1)  (0) 2018.09.03
댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/05   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
글 보관함