In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.autograd import Variable
In [2]:
# Synthetic regression data: N_SAMPLES rows of N_FEATURES features drawn from
# N(0, 2^2), and targets drawn independently from N(0, 0.1^2). The targets do
# not depend on the features, so this data only exercises the training loop;
# there is no real signal for the model to learn.
SEED = 0
N_SAMPLES, N_FEATURES = 800, 69
np.random.seed(SEED)     # reproducible data across Restart & Run All
torch.manual_seed(SEED)  # also fixes nn.Linear's random initial weights created later
train_X = np.float32(np.random.normal(scale=2, size=(N_SAMPLES, N_FEATURES)))
train_y = np.float32(np.random.normal(scale=0.1, size=(N_SAMPLES, 1)))
In [3]:
class LinearRegression(nn.Module):
    """Single-layer linear regression trained with full-batch SGD on MSE loss.

    The training data is captured at construction time; ``fit`` then performs
    one gradient step over the whole dataset per epoch.

    Args:
        inputs: 2-D array-like, one row per sample; ``inputs.shape[1]`` sets the
            number of input features. Converted to float32 before use.
        targets: 2-D array-like, one row per sample; ``targets.shape[1]`` sets
            the number of outputs. Converted to float32 before use.
        learning_rate: SGD step size.
    """

    def __init__(self, inputs, targets, learning_rate=1e-4):
        super(LinearRegression, self).__init__()
        self._train_X = inputs
        self._train_y = targets
        self._train_X_size = inputs.shape[1]
        self._train_y_size = targets.shape[1]
        self._learning_rate = learning_rate
        self._linear = nn.Linear(self._train_X_size, self._train_y_size)
        # Loss and optimizer
        self._loss_function = nn.MSELoss()
        self._optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)

    @staticmethod
    def _to_tensor(array):
        """Convert an array-like to a float32 tensor (nn.Linear's default dtype).

        ``ascontiguousarray`` also handles float64 input and negative-stride
        views, both of which ``torch.from_numpy`` alone would reject or mismatch.
        """
        return torch.from_numpy(np.ascontiguousarray(array, dtype=np.float32))

    def forward(self, x):
        """Apply the linear layer to a float32 tensor ``x``."""
        return self._linear(x)

    def fit(self, training_epochs=1e3, display=1e2):
        """Train on the stored data with full-batch gradient descent.

        Args:
            training_epochs: Number of epochs; floats such as ``1e3`` are
                truncated to ``int``.
            display: Print the loss every ``display`` epochs. A value <= 0
                disables progress printing.

        The most recent batch loss is kept in ``self._loss``.
        """
        n_epochs = int(training_epochs)
        display = int(display)
        # Every epoch uses the full dataset, so convert it once outside the loop.
        inputs = self._to_tensor(self._train_X)
        targets = self._to_tensor(self._train_y)
        for epoch in range(n_epochs):
            self._optimizer.zero_grad()  # clear gradients accumulated by the previous step
            outputs = self(inputs)  # forward pass through the linear layer
            self._loss = self._loss_function(outputs, targets)  # full-batch MSE loss
            self._loss.backward()  # backpropagate the error
            self._optimizer.step()
            if display > 0 and (epoch + 1) % display == 0:
                # .item() extracts a Python float; indexing a 0-dim tensor
                # with [0] raises IndexError on modern PyTorch.
                print('Epoch (%d/%d), loss:%.4f' % (epoch + 1, n_epochs, self._loss.item()))

    def pred(self, X):
        """Return predictions for ``X`` as a float32 NumPy array (one row per sample)."""
        with torch.no_grad():  # inference only: no autograd graph is needed
            return self(self._to_tensor(X)).numpy()
In [4]:
# Build the model on the synthetic data, then train it with the default
# schedule: 1000 epochs, with the loss printed every 100 epochs.
a = LinearRegression(inputs=train_X, targets=train_y)
a.fit(training_epochs=1e3, display=1e2)
Epoch (100/1000), loss:1.3313
Epoch (200/1000), loss:1.1228
Epoch (300/1000), loss:0.9495
Epoch (400/1000), loss:0.8052
Epoch (500/1000), loss:0.6846
Epoch (600/1000), loss:0.5837
Epoch (700/1000), loss:0.4991
Epoch (800/1000), loss:0.4280
Epoch (900/1000), loss:0.3680
Epoch (1000/1000), loss:0.3173
In [5]:
# Model predictions for every training row (displayed as the cell output).
a.pred(X=train_X)
Out[5]:
array([[-0.46727881],
[-0.2737641 ],
[ 1.05788863],
[-0.40368289],
[ 0.9073351 ],
[-0.33162621],
[-0.18489531],
[ 0.17736974],
[ 0.94153154],
[-1.37812138],
[ 1.08444297],
[ 0.56933361],
[-0.93551493],
[-0.58512652],
[-0.0546658 ],
[ 0.54009169],
[-0.01503232],
[ 0.7445389 ],
[-0.51376849],
[-0.39794806],
[-0.42978182],
[ 0.11550209],
[ 0.16118822],
[ 0.89003134],
[ 0.08549362],
[-0.12640905],
[ 0.42096788],
[ 0.68774188],
[ 0.24705368],
[-0.38662568],
[-0.009137 ],
[ 0.06563276],
[ 0.31275877],
[ 0.86082524],
[-0.63355726],
[ 0.58423811],
[ 0.59983546],
[-0.0220359 ],
[ 0.63871819],
[-0.67871124],
[ 0.07216719],
[-0.83627725],
[ 1.23007071],
[ 0.17995572],
[ 0.82321936],
[-0.23555687],
[ 0.247641 ],
[-0.47473851],
[ 0.07626043],
[-1.07636189],
[-0.43494901],
[ 0.36791253],
[ 0.11122737],
[-0.30446425],
[ 0.68509686],
[ 0.08989926],
[ 0.04426361],
[ 0.14791211],
[-0.56430852],
[-0.36163244],
[-0.1383675 ],
[-0.36726722],
[-0.17415518],
[ 0.70036316],
[-0.62004799],
[-0.23066962],
[-0.71386451],
[ 0.16563061],
[-0.05870824],
[ 0.53666085],
[-0.26176867],
[-0.33016458],
[ 0.34658405],
[ 0.60285509],
[ 0.35613638],
[-0.58566266],
[-0.67799783],
[ 0.00185735],
[-0.55233228],
[ 0.98762286],
[-0.21431956],
[ 0.5494563 ],
[-0.21009275],
[ 0.69777751],
[-0.03421655],
[ 0.37913439],
[ 0.18968117],
[-1.27394319],
[-0.77346385],
[-0.46639839],
[ 0.23492968],
[ 0.9172771 ],
[ 0.15780166],
[ 1.55143464],
[-0.09927257],
[-1.14908171],
[ 0.28747687],
[ 0.31212932],
[-1.04667866],
[ 0.16796905],
[ 0.45362186],
[ 0.62833196],
[ 0.53346765],
[ 0.32354438],
[-0.37779015],
[-0.61570638],
[-0.05419575],
[ 0.69903791],
[ 0.59337538],
[ 0.77665079],
[ 0.11431374],
[-0.13102683],
[ 0.6272881 ],
[ 0.47650045],
[ 1.31695855],
[-0.32094696],
[-0.4573687 ],
[ 1.04163384],
[ 0.9533608 ],
[ 0.47069952],
[-0.5432148 ],
[ 0.03210211],
[ 0.22089702],
[ 1.28343439],
[-0.25111344],
[ 0.22281212],
[-1.10027289],
[ 1.05092216],
[-0.12315021],
[ 1.23107111],
[-0.12206171],
[ 0.79479825],
[-0.27658173],
[-0.03257161],
[-0.23320147],
[ 0.02830369],
[ 0.4720974 ],
[ 0.96951437],
[ 0.93682319],
[ 1.13310838],
[ 0.29869992],
[ 0.09427438],
[ 0.88912523],
[ 0.42220753],
[ 0.0372902 ],
[-1.0529213 ],
[-0.40246773],
[ 0.4762809 ],
[-0.68440276],
[-0.84210074],
[-0.33824477],
[-0.56899309],
[-0.09058844],
[-0.40620026],
[ 0.95020068],
[-0.88972116],
[-0.18232229],
[-0.29647139],
[ 0.37699977],
[ 0.31088534],
[-0.55553824],
[ 0.73128766],
[-0.31054783],
[ 0.24725953],
[-0.17362219],
[ 0.3863723 ],
[ 0.2109144 ],
[-0.18588248],
[ 1.01030123],
[ 0.4374027 ],
[-0.31382534],
[-1.11601281],
[ 1.06045115],
[-0.70684767],
[ 0.79756087],
[-0.27063611],
[ 0.02249322],
[ 0.82708502],
[-0.33988816],
[ 1.57267392],
[ 0.50086051],
[-0.17204627],
[-0.4206275 ],
[ 0.20145682],
[ 0.7343356 ],
[ 0.71742195],
[-0.03619666],
[ 0.33785543],
[ 0.67208284],
[-0.84039855],
[-0.22263804],
[ 0.17220879],
[ 0.30860427],
[-0.29043838],
[ 0.18191525],
[-0.21931702],
[ 0.07588533],
[ 0.49802545],
[ 0.33642587],
[ 0.38563755],
[ 0.63775098],
[-0.40785185],
[ 0.32480064],
[ 0.43162766],
[ 0.05904747],
[-1.50012064],
[ 0.67142713],
[ 0.58435488],
[ 0.31153658],
[ 0.35355875],
[-0.80490679],
[ 0.35498881],
[ 0.74034828],
[ 0.57115614],
[ 0.23331717],
[ 0.76214534],
[-0.84203553],
[ 1.03372097],
[-0.46389315],
[ 0.13253121],
[-0.1342862 ],
[-0.37817308],
[ 0.23515275],
[ 0.16335416],
[ 0.32894057],
[ 1.26682103],
[-1.42864013],
[ 0.17241034],
[-0.10295772],
[ 0.50833976],
[ 1.07998681],
[ 0.33910474],
[ 0.27489933],
[-0.37901121],
[-0.44782141],
[-0.15007547],
[-0.27165475],
[ 0.18508086],
[-0.18758163],
[ 0.27747795],
[ 0.2698271 ],
[ 0.44019037],
[-0.09881674],
[-0.52289158],
[-0.58016407],
[-0.55769914],
[ 0.35220537],
[-0.71809584],
[ 0.19504184],
[ 0.32061931],
[ 0.17160225],
[ 0.13654795],
[ 0.71740389],
[ 0.30679858],
[ 0.20361051],
[-0.07986293],
[ 1.49066842],
[-0.62082833],
[-0.18991417],
[ 0.2118383 ],
[-0.79518676],
[ 0.65260118],
[ 0.16324854],
[-0.74143744],
[-0.32836285],
[-0.03312915],
[ 0.04903696],
[ 0.46051851],
[-0.94917846],
[-0.30457327],
[ 0.47191548],
[-0.56378162],
[ 0.28875974],
[-0.14943188],
[ 0.22196722],
[-0.29553759],
[ 0.23299366],
[ 0.06777838],
[ 0.28480262],
[ 0.40217543],
[-0.08229988],
[-0.56035197],
[ 0.44685695],
[ 0.26760721],
[-0.4480128 ],
[ 0.36270693],
[ 0.39184597],
[ 1.2296629 ],
[ 0.27685833],
[-0.29883829],
[ 0.87571359],
[ 0.11941351],
[ 0.45306501],
[-0.49373791],
[-0.01900671],
[-0.71475434],
[-1.08297241],
[ 0.38645425],
[-0.11002251],
[ 0.08480822],
[ 0.60551834],
[ 0.68000633],
[-0.37446484],
[ 0.28558043],
[ 0.14444607],
[ 0.74918902],
[-0.70037574],
[-0.19860438],
[-0.18492687],
[-0.2467089 ],
[-0.2727268 ],
[ 0.46970311],
[ 1.63833928],
[ 0.01708406],
[ 0.07640977],
[ 0.01324297],
[-0.74184459],
[ 0.02552059],
[-0.05239459],
[-0.09579279],
[-1.18469524],
[ 1.11568236],
[-0.37881571],
[-0.05398729],
[-0.13281024],
[-0.20784631],
[-0.63487339],
[ 0.11241669],
[ 0.38291845],
[-0.56777012],
[ 0.47283524],
[-0.57235891],
[ 0.62945306],
[ 0.78239769],
[ 0.25560462],
[-0.06890624],
[-0.09904327],
[ 0.72309995],
[ 0.27688411],
[-0.01648499],
[-0.2638638 ],
[ 0.04626647],
[ 0.2585853 ],
[-0.08818135],
[-0.67361593],
[-0.33928898],
[-0.46787903],
[ 0.71672267],
[ 0.19206578],
[ 0.36239332],
[ 0.20416805],
[-0.08599546],
[-0.82836252],
[ 0.11237525],
[ 0.92147261],
[ 0.614411 ],
[-0.34799397],
[ 0.68631375],
[ 0.93303359],
[ 0.94622296],
[ 0.32473505],
[ 0.12558074],
[-0.75656539],
[-0.43628553],
[ 0.22439614],
[ 0.81300449],
[-0.24929121],
[ 0.67036402],
[-0.34255323],
[ 0.33286998],
[-0.22178271],
[ 0.27287039],
[ 0.31916785],
[-0.38992885],
[-0.44305053],
[-0.31106797],
[ 0.61986959],
[-0.25370497],
[-0.12899595],
[-0.42957893],
[ 0.58483469],
[-0.46702805],
[-0.49193183],
[-0.13104644],
[-0.39723346],
[ 1.52816248],
[-0.78577769],
[ 0.46633014],
[ 0.04623763],
[-0.24196044],
[-0.10803916],
[-0.2881698 ],
[ 0.63831007],
[-0.82407862],
[-0.01170271],
[ 0.69322157],
[-0.5017786 ],
[ 1.18716002],
[ 0.39446518],
[-0.24519089],
[ 0.74265879],
[-0.31898609],
[-0.12841737],
[ 0.53461844],
[ 0.08317027],
[ 1.40966749],
[ 0.17753935],
[ 0.51147407],
[-0.99975133],
[ 0.12488622],
[ 0.18601325],
[ 0.66738188],
[ 0.42723215],
[-0.25254002],
[ 0.07515737],
[ 1.0625335 ],
[-0.6224317 ],
[-0.04716713],
[ 0.73639989],
[-0.07007905],
[-0.08152183],
[-0.26749712],
[ 0.7328738 ],
[ 0.63918918],
[-0.41725296],
[-0.04829331],
[ 0.20069724],
[-0.96875739],
[ 0.02582369],
[ 0.87910485],
[ 0.65803856],
[ 0.22482005],
[-0.17080265],
[ 0.60706478],
[-0.58735651],
[-0.3689813 ],
[ 0.02913303],
[-0.19729325],
[ 0.05154093],
[ 1.3757931 ],
[-0.29170203],
[ 0.10140788],
[-0.84000373],
[ 0.40076745],
[ 0.4035421 ],
[-0.84458983],
[-0.13895413],
[-0.40040359],
[ 0.0197158 ],
[-0.52464128],
[ 0.26832265],
[ 0.17876062],
[ 0.06820909],
[ 0.99942982],
[ 0.36774647],
[ 0.28691959],
[-0.69095826],
[ 0.31019875],
[-0.19261858],
[-0.54512566],
[-0.97174621],
[ 0.32773006],
[ 0.72121447],
[ 0.42360416],
[-0.70742697],
[ 0.49526089],
[ 1.51969481],
[-0.14510942],
[ 0.47062743],
[-0.0805038 ],
[-0.62218893],
[-0.15304625],
[-0.18709144],
[ 0.29036117],
[-0.35462111],
[ 0.79737014],
[ 1.50713539],
[ 0.14232478],
[ 0.30431071],
[-0.35210389],
[-0.0099313 ],
[ 0.08750501],
[ 0.52289563],
[-0.21382996],
[ 0.86614323],
[-0.3611007 ],
[ 0.66799498],
[-0.0742574 ],
[ 0.42857921],
[-0.37529054],
[-0.54469341],
[-0.44745681],
[-0.94318914],
[-0.43181208],
[-0.68224454],
[ 0.05364816],
[ 0.48492202],
[-0.26673377],
[-0.65304083],
[ 0.10603052],
[ 0.46871355],
[-0.79412782],
[ 0.07105336],
[-0.51376021],
[-0.3923685 ],
[ 0.27739224],
[ 0.28864709],
[-0.17036608],
[-1.16720772],
[ 0.06956331],
[-0.00730576],
[-0.71766394],
[ 0.28128678],
[ 0.1672478 ],
[-0.02370133],
[ 0.50171041],
[-0.07009865],
[ 0.01061276],
[ 0.05633564],
[-0.49955001],
[-0.2817063 ],
[-0.19101709],
[-0.0033436 ],
[ 1.12950516],
[-0.5755657 ],
[ 0.21608806],
[-0.8256532 ],
[ 0.9520039 ],
[ 0.80697286],
[-0.07756127],
[-0.3151671 ],
[ 0.14263672],
[ 0.01480057],
[-0.11259148],
[ 0.90663183],
[ 0.21596032],
[ 0.98260522],
[-0.49363783],
[ 0.12438676],
[ 0.06250332],
[-0.43569633],
[ 0.25145367],
[ 0.35966235],
[-0.70552313],
[ 0.17478001],
[ 0.10473324],
[-0.16089395],
[-0.60305721],
[ 0.19422024],
[ 0.09758363],
[-0.43904236],
[-0.54022515],
[ 0.73896372],
[ 0.6099056 ],
[-0.00385395],
[ 0.43497452],
[ 0.38447818],
[-0.80793869],
[ 0.73171014],
[-0.60187632],
[-0.44840297],
[-1.0055145 ],
[-0.4168795 ],
[-0.37266061],
[ 0.20816579],
[-0.14508697],
[ 0.13987735],
[-0.72863972],
[-0.05385435],
[-0.58654815],
[-0.10666194],
[ 0.22793192],
[-0.38269076],
[ 0.87929487],
[ 0.83478677],
[ 0.51511145],
[-0.23796147],
[ 1.36834955],
[ 0.08172721],
[-0.33726388],
[ 0.55346268],
[-0.42647991],
[ 0.38795599],
[ 0.56526667],
[-0.45462832],
[ 0.98716998],
[-0.30336359],
[ 0.37951115],
[ 0.38037071],
[ 0.10951657],
[-0.2639437 ],
[ 0.48513916],
[ 0.46821287],
[-0.28317851],
[ 0.4840022 ],
[ 1.04958975],
[ 0.10377023],
[ 0.10505842],
[-0.02940079],
[-0.42901555],
[ 0.78063476],
[ 0.45006946],
[ 1.0737983 ],
[ 0.52802682],
[ 0.18564761],
[ 1.72601175],
[ 0.06325509],
[ 0.3689467 ],
[-0.8282966 ],
[-1.15416896],
[-0.32920754],
[-0.89659929],
[ 1.05815804],
[ 0.70554912],
[ 0.80964494],
[ 0.06259336],
[ 0.01603659],
[ 0.32098967],
[ 0.14421856],
[ 0.0266166 ],
[ 1.07772565],
[-0.51597631],
[ 0.74300241],
[-0.33208385],
[ 0.45798764],
[ 0.11365047],
[ 0.5227387 ],
[ 0.79766619],
[-0.24054846],
[ 0.2758162 ],
[-0.09917698],
[-0.20623481],
[-0.04829779],
[-0.01329603],
[ 0.47950909],
[ 1.02781618],
[ 0.69702411],
[ 0.26320827],
[-0.0665734 ],
[ 0.13802579],
[ 0.40244782],
[ 1.01526988],
[-0.21911451],
[-0.38922921],
[ 0.54782462],
[ 0.3258568 ],
[ 0.33680406],
[ 0.87337738],
[-0.03875498],
[ 1.00650442],
[ 0.85590285],
[-0.25818923],
[-0.01531341],
[-0.61894882],
[ 0.69424355],
[-0.69726413],
[-0.01645393],
[-0.07206679],
[-0.16152772],
[ 0.22644481],
[-0.20973408],
[ 1.01885641],
[ 1.05397451],
[ 0.32664588],
[ 0.40778533],
[ 0.78050768],
[-0.16207248],
[-0.02656209],
[ 0.60948062],
[ 0.78253442],
[-0.27905527],
[ 0.85446668],
[ 0.29925945],
[-0.00957713],
[-0.26511827],
[-0.23545912],
[ 0.70573199],
[ 0.40559021],
[ 0.11790828],
[-0.20682126],
[ 0.45438844],
[ 0.82508808],
[-0.4916046 ],
[-0.04704569],
[-0.38345656],
[-0.07053103],
[-0.36543491],
[ 1.01872051],
[ 0.43383664],
[-0.67635405],
[-0.20012966],
[ 0.05783844],
[ 0.20844081],
[ 0.73643136],
[ 0.52070045],
[-0.43074551],
[ 0.51860076],
[ 0.39868322],
[-0.07907497],
[ 0.31996629],
[ 0.66879386],
[-0.12571037],
[ 0.32119367],
[ 0.96006119],
[ 0.40066347],
[-0.00539017],
[-0.15629715],
[ 0.96241993],
[-0.19905522],
[ 0.85837513],
[ 0.80206525],
[ 0.97351396],
[-0.24246326],
[-0.48889527],
[-0.41130775],
[ 0.15380764],
[ 0.8356629 ],
[-0.11419173],
[ 0.65623331],
[ 0.1342153 ],
[-0.05273577],
[-0.44047645],
[-0.23370352],
[-0.15489477],
[ 0.03935406],
[-0.19836313],
[-0.99544287],
[ 0.55117708],
[-0.59885556],
[-0.11648145],
[ 0.71632475],
[ 0.16874671],
[-0.67526364],
[-0.23234156],
[ 0.17508182],
[ 0.08158282],
[-0.15935606],
[-0.99951506],
[ 0.28044924],
[-0.15433758],
[-0.62045538],
[-0.79123241],
[ 0.58547592],
[ 0.61869878],
[ 0.10728819],
[ 0.43084788],
[-0.1509831 ],
[-0.18121251],
[ 0.77610898],
[ 0.42763799],
[ 0.29294077],
[-0.18773636],
[-0.52435344],
[ 0.05225879],
[-0.48631665],
[ 0.44314644],
[ 0.7868011 ],
[ 0.36901325],
[ 0.40812942],
[ 0.14549929],
[-0.60122514],
[ 0.03218338],
[ 1.33894801],
[ 0.25125071],
[-0.95109797],
[ 0.33375832],
[ 0.59405452],
[-0.34858713],
[-0.30604652],
[ 0.23997733],
[ 0.02507449],
[ 0.14913729],
[-0.19059059],
[-0.30665812],
[ 0.56229717],
[ 0.01064511],
[-0.82806253],
[-0.14076799],
[-0.13117707],
[ 0.79744756],
[-0.13299817],
[ 0.06150784],
[ 0.96350777],
[-0.08150166],
[ 0.35942987],
[-0.87045056],
[ 0.35863313],
[-0.19681925],
[ 0.3388117 ],
[-0.49157837],
[-0.25832838],
[ 0.08012209],
[ 0.6801765 ],
[ 0.6910466 ],
[ 0.18716562],
[ 0.9650228 ],
[ 0.58523625],
[-0.61542195],
[ 0.65314984],
[-0.3882488 ]], dtype=float32)
Content source: AlphaSmartDog/DeepLearningNotes
Similar notebooks: