|
| 1 | +from tensorflow.keras import datasets, models, layers, losses |
| 2 | +import tensorflow as tf |
| 3 | +import math |
| 4 | +import numpy as np |
| 5 | +import random |
| 6 | +import matplotlib.pyplot as plt |
| 7 | + |
| 8 | +number = 100 |
| 9 | + |
| 10 | +# randomPoints on the sin graph |
| 11 | +# number - number of points |
| 12 | +# angles - tuple for range of angles |
| 13 | +def randomPoints(number, angles= (-360,360)): |
| 14 | + angleList = [] |
| 15 | + sinList = [] |
| 16 | + while(number>0): |
| 17 | + angle = random.uniform(angles[0],angles[1]) |
| 18 | + angle = angle * math.pi / 180 # convert to radians |
| 19 | + angleList.append([angle]) |
| 20 | + sinList.append([math.sin(angle)]) |
| 21 | + number = number -1 |
| 22 | + return angleList, sinList |
| 23 | + |
| 24 | +# only specific points on the sin graph |
| 25 | +# number - number of points |
| 26 | +# angles - target angle |
| 27 | +def selectPoints(number, angle=30): |
| 28 | + angleList = [] |
| 29 | + sinList = [] |
| 30 | + while(number>0): |
| 31 | + radangle = angle * math.pi / 180 # convert to radians |
| 32 | + angleList.append([radangle]) |
| 33 | + sinList.append([math.sin(radangle)]) |
| 34 | + angleList.append([-radangle]) |
| 35 | + sinList.append([math.sin(-radangle)]) |
| 36 | + angle = angle + 60 |
| 37 | + number = number -1 |
| 38 | + return angleList, sinList |
| 39 | + |
| 40 | +angles = (-360,360) |
| 41 | +angleList, sinList = randomPoints(10000, angles) |
| 42 | +#angleList, sinList = selectPoints(10000) |
| 43 | + |
| 44 | +model = models.Sequential() |
| 45 | +model.add(layers.Dense(10, activation='tanh', input_shape=(1,))) |
| 46 | +model.add(layers.Dense(1, activation=None)) |
| 47 | +model.compile(optimizer='Adam', |
| 48 | + loss=losses.MeanSquaredError(), |
| 49 | + metrics=['mean_squared_error']) |
| 50 | +history = model.fit(np.array(angleList),np.array(sinList), epochs=200) |
| 51 | +plt.plot(history.history['mean_squared_error'], label='mean_squared_error') |
| 52 | +plt.xlabel('Epoch') |
| 53 | +plt.ylabel('MSE') |
| 54 | +plt.ylim([0.0, 0.2]) |
| 55 | +plt.legend(loc='lower right') |
| 56 | + |
| 57 | +angleTest, sinTest = randomPoints(10) |
| 58 | +print(model.predict(np.array(angleTest))) |
| 59 | +print(sinTest) |
| 60 | + |
| 61 | +plt.show() |
0 commit comments