# -*- coding: utf-8 -*-

import tensorflow as tf

# 선형회귀 모델(Wx + b)을 위한 tf.Variable을 선언
W = tf.Variable(tf.random.normal(shape=[1]))
b = tf.Variable(tf.random.normal(shape=[1]))

@tf.function
def linear_model(x):
  return W*x + b

# 손실 함수를 정의
# MSE 손실함수 \mean{(y' - y)^2}
@tf.function
def mse_loss(y_pred, y):
  return tf.reduce_mean(tf.square(y_pred - y))

# 최적화를 위한 그라디언트 디센트 옵티마이저를 정의
optimizer = tf.optimizers.SGD(0.01)

# 최적화를 위한 function을 정의
@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    y_pred = linear_model(x)
    loss = mse_loss(y_pred, y)
  gradients = tape.gradient(loss, [W, b])
  optimizer.apply_gradients(zip(gradients, [W, b]))

# 트레이닝을 위한 입력값과 출력값을 준비
x_train = [1, 2, 3, 4]
y_train = [2, 4, 6, 8]

# 경사하강법을 1000번 수행
for i in range(1000):
  train_step(x_train, y_train)

# 테스트를 위한 입력값을 준비
x_test = [3.5, 5, 5.5, 6]
# 테스트 데이터를 이용해 학습된 선형회귀 모델이 데이터의 경향성(y=2x)을 잘 학습했는지 측정
# 예상되는 참값 : [7, 10, 11, 12]
print(linear_model(x_test).numpy())

 

1-8 = 가설 정의

파라미터 W,b 정의

tf.Variable = api

random.normal = 가우시안 distribution에서 random 값을 뽑음

shape 인자 값 = 지정하고자 하는 모델의 파라미터 shape. 선형 회귀에선 하나의 데이터(x)를 받아 하나를 도출해내는 것이기 때문에 1차원으로 지정함

 

9-11 =Linear Regression 함수 작성

-----------------------------------------------------------------------------------------------

13 - 17 = 손실함수 정의

loss(예측값, 정답값)

square = 제곱함수

reduce_mean = 평균함수

 

-----------------------------------------------------------------------------------------------

19 - 29 = optimization 정의, gradient descent 정의

SGD : mini-batch gradient descent 실행해 주는 기본적인 옵티마이저

0.01 = learning rate

 

train_step : gradient descent를 한 단계 실행해주는 함수

y_pred : linear regression에 기반한 예측값

mse : mean_sqaured_error

gradients : loss에 대한 해당 파라미터(W,b)에 대한 gradient 값

zip : 계산한 gradient와 갱신 대상 파라미터(W,b)를 묶어줌 (   zip([1,2,3],[4,5,6]) ==> (1,4), (2,5), (3,6)   )

 

 

딥러닝 알고리즘의 "학습" = apply_gradient 함수를 반복 호출해 랜덤 파라미터를 학습 방향성에 맞게 변형 시켜주는 것!

 

실행 결과

 

 

'Study > Deep Learning' 카테고리의 다른 글

다층 퍼셉트론 MLP  (0) 2021.10.15
TensorFlow 2.0과 Softmax Regression을 이용한 MNIST 숫자분류기 구현  (0) 2021.10.15
TensorFlow  (0) 2021.10.15
다양한 Computer Vision 문제 영역  (0) 2021.10.14
머신러닝 Data 종류  (0) 2021.10.14

+ Recent posts