How can you create a simple model from scratch in TensorFlow without using higher-level APIs like Keras?

https://en.wikipedia.org/wiki/Linear_regression

Imports

Let’s start by importing the libraries we need: tensorflow for the model, numpy for generating the data and matplotlib for visualising the model in action.

import tensorflow as tf 
import numpy as np
import matplotlib.pyplot as plt

Data Creation

Now let’s create the data:

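X = tf.constant(np.linspace(0, 2, 2000), dtype=tf.float32)
Y = X * tf.exp(-X**2)  # target curve: y = x * exp(-x^2)
plt.plot(X, Y)
plt.show()

# Expand each input x into five features; the model is linear in these features.
def make_features(X):
    f1 = tf.ones_like(X)  # Bias.
    f2 = X
    f3 = tf.square(X)
    f4 = tf.sqrt(X)
    f5 = tf.exp(X)
    return tf.stack([f1, f2, f3, f4, f5], axis=1)  # shape: (n_samples, 5)

# Linear model: (n_samples, 5) @ (5, 1) -> (n_samples, 1), squeezed to (n_samples,).
def predict(X, W):
    return tf.squeeze(X @ W, -1)

# Mean squared error between predictions and targets.
def loss_mse(X, Y, W):
    Y_hat = predict(X, W)
    errors = (Y_hat - Y)**2
    return tf.reduce_mean(errors)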
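Written as equations, the three functions above define a model that is linear in the weights even though it is non-linear in x. In the sketch below, phi(x) is the row of features built by make_features, Phi is the stacked n x 5 feature matrix and y is the vector of targets. This notation is ours, not the article's. The last expression is the gradient that GradientTape computes automatically in the next section.

```latex
\begin{aligned}
\phi(x) &= \begin{bmatrix} 1 & x & x^{2} & \sqrt{x} & e^{x} \end{bmatrix}, &
\hat{\mathbf{y}} &= \Phi W, \quad \Phi \in \mathbb{R}^{n \times 5},\; W \in \mathbb{R}^{5 \times 1}, \\
L(W) &= \frac{1}{n} \sum_{i=1}^{n} \left( \hat{y}_i - y_i \right)^{2}, &
\nabla_W L &= \frac{2}{n}\, \Phi^{\top} \left( \Phi W - \mathbf{y} \right).
\end{aligned}
```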

Gradient Function

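def compute_gradients(X, Y, W):
    # Build features from the raw inputs X, then differentiate the MSE w.r.t. W.
    with tf.GradientTape() as tape:
        loss = loss_mse(make_features(X), Y, W)
    return tape.gradient(loss, W)

# General pattern: ops on tf.Variables inside the tape are recorded, then differentiated.
w0, w1 = tf.Variable(0.0), tf.Variable(0.0)  # illustrative scalars only
with tf.GradientTape() as tape:
    loss = tf.reduce_mean((w0 + w1 * X - Y) ** 2)  # any computation using w0, w1
gradients = tape.gradient(loss, [w0, w1])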
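To check that GradientTape really returns the MSE gradient, you can compare it with the closed-form expression (2/n) * Xf^T (Xf W - y). This is a minimal sketch. It re-declares the article's helpers so that it runs on its own. It also uses float64 (our choice, to keep round-off small) and arbitrary non-zero weights that are purely illustrative.

```python
import numpy as np
import tensorflow as tf

# Same feature map and loss as in the article, repeated so this snippet runs alone.
def make_features(X):
    return tf.stack([tf.ones_like(X), X, tf.square(X), tf.sqrt(X), tf.exp(X)], axis=1)

def loss_mse(Xf, Y, W):
    Y_hat = tf.squeeze(Xf @ W, -1)
    return tf.reduce_mean((Y_hat - Y) ** 2)

X = tf.constant(np.linspace(0, 2, 2000), dtype=tf.float64)
Y = X * tf.exp(-X**2)
Xf = make_features(X)

# Arbitrary non-zero weights (illustrative), so the check is not done at W = 0.
W = tf.Variable(tf.constant([[0.1], [-0.2], [0.3], [0.05], [-0.01]], dtype=tf.float64))

# Gradient from automatic differentiation.
with tf.GradientTape() as tape:
    loss = loss_mse(Xf, Y, W)
grad_tape = tape.gradient(loss, W)

# Gradient from the closed-form expression (2/n) * Xf^T (Xf W - y).
n = Xf.shape[0]
residual = Xf @ W - tf.reshape(Y, (-1, 1))  # shape (n, 1)
grad_formula = (2.0 / n) * tf.transpose(Xf) @ residual

max_diff = float(tf.reduce_max(tf.abs(grad_tape - grad_formula)))
print("max |tape - formula| =", max_diff)
```

The printed difference should sit at floating-point round-off level. Anything larger would point to a mistake in the loss or in the shapes.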

Model Training

STEPS = 2000  # try 50000 for a closer fit
LEARNING_RATE = .02
Xf = make_features(X)
n_weights = Xf.shape[1]
W = tf.Variable(np.zeros((n_weights, 1), dtype=np.float32))  # weights start at zero

# For plotting
steps, losses = [], []
for step in range(1, STEPS + 1):
    dW = compute_gradients(X, Y, W)    # gradient of the MSE w.r.t. W
    W.assign_sub(dW * LEARNING_RATE)   # gradient-descent update: W <- W - lr * dW
    if step % 100 == 0:
        loss = loss_mse(Xf, Y, W)
        steps.append(step)
        losses.append(float(loss))

# Plot the training curve once the loop has finished.
plt.figure()
plt.plot(steps, losses)
plt.show()
print("STEP: {} MSE: {}".format(STEPS, loss_mse(Xf, Y, W)))
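Because the model is linear in W, you can also compute the lowest MSE it can possibly reach directly, with an ordinary least-squares solve. Compare that number with the MSE printed above. The gap shows how far 2000 steps of gradient descent are from convergence, and whether raising STEPS is worth it. This sketch swaps in NumPy's `np.linalg.lstsq` purely as a reference point; it is not part of the article's TensorFlow-only approach.

```python
import numpy as np

# Rebuild the same data and features in float64 with NumPy alone.
x = np.linspace(0, 2, 2000)
y = x * np.exp(-x**2)
Phi = np.stack([np.ones_like(x), x, x**2, np.sqrt(x), np.exp(x)], axis=1)  # (2000, 5)

# Least-squares solution: the W that minimises mean((Phi @ W - y)**2).
W_ls, *_ = np.linalg.lstsq(Phi, y, rcond=None)
mse_ls = np.mean((Phi @ W_ls - y) ** 2)

# Baselines: all-zero weights (the training loop's starting point)
# and predicting the mean of y everywhere.
mse_zero = np.mean(y ** 2)
mse_mean = np.var(y)

print("least-squares MSE:", mse_ls)
print("zero-weights MSE: ", mse_zero)
```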

Testing

Let’s now see how well the trained model fits the data by plotting its predictions against the actual curve.

# plt.figure() creates a new figure (or activates an existing one).
plt.figure()
# plt.plot() draws y against x; each call adds one line to the current axes.
plt.plot(X, Y, label='actual')
plt.plot(X, predict(Xf, W), label='predicted')
# plt.legend() places a legend on the axes, using each line's label.
plt.legend()
plt.show()
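The plot shows the fit qualitatively; a couple of numbers make it quantitative. Below is a minimal sketch of a helper for that. `evaluate` is a hypothetical name, not part of the article. It repeats the computation of predict() and reports the MSE together with the largest absolute error. After training it would be called as `evaluate(Xf, Y, W)`; here it runs on a tiny hand-checkable example.

```python
import tensorflow as tf

def evaluate(Xf, Y, W):
    """Return the MSE and maximum absolute error of the linear model Xf @ W."""
    Y_hat = tf.squeeze(Xf @ W, -1)  # same computation as predict()
    errors = Y_hat - Y
    return {
        "mse": float(tf.reduce_mean(errors ** 2)),
        "max_abs_error": float(tf.reduce_max(tf.abs(errors))),
    }

# Tiny example: predictions are [1, 2] and targets are [1, 3].
Xf_demo = tf.constant([[1.0, 0.0], [1.0, 1.0]])
W_demo = tf.Variable([[1.0], [1.0]])
Y_demo = tf.constant([1.0, 3.0])
print(evaluate(Xf_demo, Y_demo, W_demo))  # {'mse': 0.5, 'max_abs_error': 1.0}
```

The maximum absolute error is worth reporting next to the MSE. A single badly fitted region, such as one end of the interval, can hide inside a small average.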
