# -*- coding: utf-8 -*-
import tensorflow as tf
# 선형회귀 모델(Wx + b)을 위한 tf.Variable을 선언
W = tf.Variable(tf.random.normal(shape=[1]))
b = tf.Variable(tf.random.normal(shape=[1]))
@tf.function
def linear_model(x):
return W*x + b
# 손실 함수를 정의
# MSE 손실함수 \mean{(y' - y)^2}
@tf.function
def mse_loss(y_pred, y):
return tf.reduce_mean(tf.square(y_pred - y))
# 최적화를 위한 그라디언트 디센트 옵티마이저를 정의
optimizer = tf.optimizers.SGD(0.01)
# 최적화를 위한 function을 정의
@tf.function
def train_step(x, y):
with tf.GradientTape() as tape:
y_pred = linear_model(x)
loss = mse_loss(y_pred, y)
gradients = tape.gradient(loss, [W, b])
optimizer.apply_gradients(zip(gradients, [W, b]))
# 트레이닝을 위한 입력값과 출력값을 준비
x_train = [1, 2, 3, 4]
y_train = [2, 4, 6, 8]
# 경사하강법을 1000번 수행
for i in range(1000):
train_step(x_train, y_train)
# 테스트를 위한 입력값을 준비
x_test = [3.5, 5, 5.5, 6]
# 테스트 데이터를 이용해 학습된 선형회귀 모델이 데이터의 경향성(y=2x)을 잘 학습했는지 측정
# 예상되는 참값 : [7, 10, 11, 12]
print(linear_model(x_test).numpy())
1-8 = 가설 정의
파라미터 W,b 정의
tf.Variable = api
random.normal = 가우시안 distribution에서 random 값을 뽑음
shape 인자 값 = 지정하고자 하는 모델의 파라미터 shape. 선형 회귀에선 하나의 데이터(x)를 받아 하나를 도출해내는 것이기 때문에 1차원으로 지정함
9-11 =Linear Regression 함수 작성
-----------------------------------------------------------------------------------------------
13 - 17 = 손실함수 정의
loss(예측값, 정답값)
square = 제곱함수
reduce_mean = 평균함수
-----------------------------------------------------------------------------------------------
19 - 29 = optimization 정의, gradient descent 정의
SGD : mini-batch gradient descent 실행해 주는 기본적인 옵티마이저
0.01 = learning rate
train_step : gradient descent를 한 단계 실행해주는 함수
y_pred : linear regression에 기반한 예측값
mse : mean_sqaured_error
gradients : loss에 대한 해당 파라미터(W,b)에 대한 gradient 값
zip : 계산한 gradient와 갱신 대상 파라미터(W,b)를 묶어줌 ( zip([1,2,3],[4,5,6]) ==> (1,4), (2,5), (3,6) )
딥러닝 알고리즘의 "학습" = apply_gradient 함수를 반복 호출해 랜덤 파라미터를 학습 방향성에 맞게 변형 시켜주는 것!
'Study > Deep Learning' 카테고리의 다른 글
다층 퍼셉트론 MLP (0) | 2021.10.15 |
---|---|
TensorFlow 2.0과 Softmax Regression을 이용한 MNIST 숫자분류기 구현 (0) | 2021.10.15 |
TensorFlow (0) | 2021.10.15 |
다양한 Computer Vision 문제 영역 (0) | 2021.10.14 |
머신러닝 Data 종류 (0) | 2021.10.14 |