Skip to content

Commit d85f768

Browse files
committed
NN to solve sin function
1 parent b69685b commit d85f768

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

Summer20/NeuralNetwork/tfsin.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from tensorflow.keras import datasets, models, layers, losses
2+
import tensorflow as tf
3+
import math
4+
import numpy as np
5+
import random
6+
import matplotlib.pyplot as plt
7+
8+
number = 100
9+
10+
# randomPoints on the sin graph
11+
# number - number of points
12+
# angles - tuple for range of angles
13+
def randomPoints(number, angles= (-360,360)):
14+
angleList = []
15+
sinList = []
16+
while(number>0):
17+
angle = random.uniform(angles[0],angles[1])
18+
angle = angle * math.pi / 180 # convert to radians
19+
angleList.append([angle])
20+
sinList.append([math.sin(angle)])
21+
number = number -1
22+
return angleList, sinList
23+
24+
# only specific points on the sin graph
25+
# number - number of points
26+
# angles - target angle
27+
def selectPoints(number, angle=30):
28+
angleList = []
29+
sinList = []
30+
while(number>0):
31+
radangle = angle * math.pi / 180 # convert to radians
32+
angleList.append([radangle])
33+
sinList.append([math.sin(radangle)])
34+
angleList.append([-radangle])
35+
sinList.append([math.sin(-radangle)])
36+
angle = angle + 60
37+
number = number -1
38+
return angleList, sinList
39+
40+
angles = (-360,360)
41+
angleList, sinList = randomPoints(10000, angles)
42+
#angleList, sinList = selectPoints(10000)
43+
44+
model = models.Sequential()
45+
model.add(layers.Dense(10, activation='tanh', input_shape=(1,)))
46+
model.add(layers.Dense(1, activation=None))
47+
model.compile(optimizer='Adam',
48+
loss=losses.MeanSquaredError(),
49+
metrics=['mean_squared_error'])
50+
history = model.fit(np.array(angleList),np.array(sinList), epochs=200)
51+
plt.plot(history.history['mean_squared_error'], label='mean_squared_error')
52+
plt.xlabel('Epoch')
53+
plt.ylabel('MSE')
54+
plt.ylim([0.0, 0.2])
55+
plt.legend(loc='lower right')
56+
57+
angleTest, sinTest = randomPoints(10)
58+
print(model.predict(np.array(angleTest)))
59+
print(sinTest)
60+
61+
plt.show()

0 commit comments

Comments
 (0)