Pytorch訓練類型:回顧
目前在準備參加人工智能訓練師的考試,算是給自己一個復習吧, 對應不同的數據有不同的方法來處理以及構建簡單的模型(應付考試夠啦~~ 手動狗頭)
一、習慣將需要用到的python庫在前面都導入進來,其實有些都是跟隨步驟而需求的。
# 導入庫
import pandas as pd import numpy as np import torch import torch.nn as nn from torch.optim import Adam import torchmetrics from torch.utils.data import Dataset, random_split, DataLoader
沒有導入畫圖的庫,如前所述,可能三級工考試的題目不需要畫圖展示的(老師是這麽說的~~~~)
#數據處理
class MyDataset(Dataset):
def __init__(self, path):
df = pd.read_csv(path).drop('car_ID', axis=1)
self.y = df['price']
self.X = df.drop('price', axis=1)
cate_cols = self.X.select_dtypes(include='object').columns
self.X = pd.get_dummies(self.X, columns=cate_cols, dtype=int)
self.X = (self.X - self.X.mean()) / self.X.std()
self.y = (self.y - self.y.mean()) / self.y.std()
self.X = torch.tensor(self.X.values, dtype=torch.float32)
self.y = torch.tensor(self.y.values, dtype=torch.float32)
print(f"特徵數: {self.X.shape[1]}, 樣本數: {self.y.shape[0]}")
def __getitem__(self, index):
return self.X[index], self.y[index]
def __len__(self):
return len(self.X)
dataset = MyDataset("CarPrice_Assignment.csv")
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size
train_data, test_data = random_split(dataset, [train_size, test_size])
print(f"訓練集大小:{len(train_data)}, 測試集大小:{len(test_data)}")
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)數據處理是采用官方的標準框架(老師說,其實把數據處理放在框架外面也可以的, 框架内部只要定義一下self.X,self.y及其他固定的寫法)
有時用GPT優化代碼的話, 如果不特別指定使用Pytorch進行劃分數據集,一般都會使用sk-learn來進行,不知道這是爲什麽? 有知道的夥伴可以告知。謝謝
# 模型構建(這段最唬人,其實也都是固定格式, 不過咱不用那些高大上的前輩的模型樣式了,就構建一下能跑起來的模型就好)
import torch.nn as nn
class CarNET(nn.Module):
def __init__(self, input_size):
super(CarNET, self).__init__()
self.network = nn.Sequential(
nn.Linear(input_size, 512),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(256, 1)
)
def forward(self, x):
return self.network(x)
input_size = len(dataset.X[1])
model = CarNET(input_size)
print(f"模型參數量: {sum(p.numel() for p in model.parameters())}")考試主要就是靠架構是否記憶的清楚,參數有沒有弄錯;
# 訓練模型(這部分可以說固定式,但其實不同玩家有不同的寫法,我水平有限,就參考普通寫法了)
lossf = nn.MSELoss()
optim = Adam(model.parameters(), lr=0.0001)
r2_score = torchmetrics.R2Score()
epochs = 10
for epoche in range(epochs):
model.train()
r2_score.reset()
total_loss = 0
for batch_X, batch_y in train_loader:
optim.zero_grad()
output = model(batch_X).squeeze() # 這裏去除一個維度,因爲去掉前面的view(1)
loss_ = lossf(output, batch_y)
loss_.backward()
optim.step()
with torch.no_grad():
total_loss += loss_.item()
r2_score.update(output, batch_y)
avg_loss = total_loss / len(train_loader)
avg_r2 = r2_score.compute().item()
print(f"Epoch {epoche + 1}: Loss: {avg_loss:.4f}, R2 Score: {avg_r2:.4f}")有個地方值得注意的,就是代碼中解釋的那一句,也就是要將數據格式維度變一下,這個在構建模型的時候也是可以進行的 使用
def forward(self, x): output = self.linear(x) return output.squeeze() # 在模型內部處理維度
那以上就是針對回歸的模型構建和訓練,最後再補充一下保存代碼
torch.save(model.state_dict(), 'model.pth') #或者 torch.save(model(), 'model.pth')
如下是AI回復的區別,可想而知,單純一個保存就有這麽多知識點, 學好pytorch是不容易的。
1. torch.save(model.state_dict(), 'model.pth') ✅ 推薦做法
# 保存模型參數torch.save(model.state_dict(), 'model.pth')# 加載時需要先創建模型實例,然後加載參數loaded_model = MyModel() # 必須有模型類定義loaded_model.load_state_dict(torch.load('model.pth'))特點:
只保存模型參數(權重、偏置等)
文件較小
加載時需要模型類的定義
更靈活:可以在不同架構間遷移參數
2. torch.save(model, 'model.pth') ⚠️ 不推薦
# 保存整個模型torch.save(model, 'model.pth')# 加載時直接還原模型實例loaded_model = torch.load('model.pth')特點:
保存整個模型對象(包括結構和參數)
文件較大
加載時不需要模型類定義
可能出現兼容性問題
詳細比較
文件大小對比:
# state_dict方式 - 通常較小torch.save(model.state_dict(), 'state_dict.pth') # 整個模型方式 - 通常較大(包含結構信息)torch.save(model, 'entire_model.pth')
加載方式對比:
# 方法1:state_dict(推薦)model = MyModel()model.load_state_dict(torch.load('model.pth'))model.eval()
# 方法2:整個模型(可能出問題)model = torch.load('model.pth')model.eval()為什麼推薦 state_dict() 方式?
1. 兼容性更好
# 使用state_dict可以輕鬆處理模型架構變化class ImprovedModel(MyModel):
def __init__(self):
super().__init__()
self.new_layer = nn.Linear(10, 5)# 仍然可以加載舊參數improved_model = ImprovedModel()improved_model.load_state_dict(torch.load('old_model.pth'), strict=False)2. 部署友好
# 生產環境中,通常只需要參數class ProductionModel(nn.Module):
# 簡化版本,可能去掉訓練特定的組件
passprod_model = ProductionModel()prod_model.load_state_dict(torch.load('trained_model.pth'))3. 避免序列化問題
整個模型保存可能因為以下原因失敗:
自定義類的定義改變
Python版本差異
PyTorch版本升級