接下來作者將利用KAN進行對鳶尾花的分類實現,體現它相對于MLP無法比擬的可解釋性、交互性特點,當然KAN也有其缺點就是目前版本訓練速度較慢

代碼實現

數據讀取

import pandas as pd
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
import numpy as np
iris = load_iris()
iris_df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
iris_df['target'] = iris.target
df = iris_df[iris_df['target'] != 2] # 只要0和1完成一個二分類問題
df.head()

這里將數據簡單梳理為二分類問題,并且將這個分類問題看作回歸問題,去探討不同輸出維度下的KAN

KAN輸出維度=1

from sklearn.model_selection import train_test_split
import torch
train_input, test_input, train_label, test_label = train_test_split(df.iloc[:, 0:4], df['target'],
test_size=0.2, random_state=42, stratify=df['target'])

# 將 DataFrame 和 Series 轉換為 np.array
train_input = train_input.to_numpy()
test_input = test_input.to_numpy()
train_label = train_label.to_numpy()
test_label = test_label.to_numpy()
# 轉換為pytorch張量
dataset = {}
dataset['train_input'] = torch.from_numpy(train_input)
dataset['test_input'] = torch.from_numpy(test_input)
dataset['train_label'] = torch.from_numpy(train_label[:,None])
dataset['test_label'] = torch.from_numpy(test_label[:,None])

分割數據集且將原始的 DataFrame 數據轉換為適合在 PyTorch 中使用的張量形式

from kan import KAN
model = KAN(width=[4,1], grid=3, k=3)
# 初始化繪制KAN
model(dataset['train_input']);
model.plot(beta=100)

這里創建一個KAN:4D輸入,1D輸出,沒有隱藏的神經元,三次樣條 (k=3),3個網格間隔 (grid=3),如果要添加隱藏的神經元在width中添加既可,它表示每層中的神經元數,例如,[2,5,5,3] 表示 2D 輸入,3D 輸出,具有 2 層 5 個隱藏神經元,創建這樣的模型后對其可視化,當前這個模型還沒有進行訓練,接下來訓練這個模型

# 定義訓練集準確率計算函數
def train_acc():
# 使用模型對訓練輸入進行預測,取預測值的第一個輸出并四舍五入
# 將預測值與訓練標簽進行比較,計算準確率
return torch.mean((torch.round(model(dataset['train_input'])[:, 0]) == dataset['train_label'][:, 0]).float())

# 定義測試集準確率計算函數
def test_acc():
# 使用模型對測試輸入進行預測,取預測值的第一個輸出并四舍五入
# 將預測值與測試標簽進行比較,計算準確率
return torch.mean((torch.round(model(dataset['test_input'])[:, 0]) == dataset['test_label'][:, 0]).float())

# 訓練模型,使用LBFGS優化器,訓練20步,計算訓練和測試集的準確率
results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc))
model.plot()

定義了兩個函數 train_acc() 和 test_acc() 分別用于計算訓練集和測試集上模型的準確率,然后使用 LBFGS 優化器對模型進行訓練,訓練步數為 20 步,并同時計算并輸出訓練和測試集的準確率,最后對模型進行可視化,對比模型初始可視化可以發現激活函數明顯不一樣了,這就是KAN對激活函數學習的一個結果,接下來我們把這個模型進行解釋性輸出

lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','tan','abs']
model.auto_symbolic(lib=lib)
formula = model.symbolic_formula()[0][0]
formula

可以發現KAN模型相對其其它深度學習框架,它可以輸出一個具體的公式,當然這個KAN是單輸出所以只有一個公式,通過這個公式它不在是一個黑箱模型,而是可以被我們所解釋的模型,實際上把相應的X值輸入公式并進行四舍五入返回的值就是0或1也就是我們的實際類別,接下來通過這個公式來輸出在訓練集、測試集上的模型精確度

def acc(formula, X, y):
batch = X.shape[0] # 獲取批量大小
correct = 0 # 初始化正確預測的數量
for i in range(batch):
# 構建替換字典,將 x_1, x_2, x_3, x_4 替換為當前樣本的值
subs_dict = {'x_1': X[i, 0], 'x_2': X[i, 1], 'x_3': X[i, 2], 'x_4': X[i, 3]}
# 使用給定的公式對當前樣本進行預測,并將結果轉換為浮點數
prediction = float(formula.subs(subs_dict))
# 四舍五入預測值,與真實標簽進行比較
if np.round(prediction) == y[i, 0]:
correct += 1
# 計算準確率
accuracy = correct / batch
return accuracy

# 計算訓練集和測試集的準確率
train_accuracy = acc(formula, dataset['train_input'], dataset['train_label'])
test_accuracy = acc(formula, dataset['test_input'], dataset['test_label'])
print('train acc of the formula:', train_accuracy)
print('test acc of the formula:', test_accuracy)

通過準確率可知這個單輸出的二分類KAN模型,表現的很好只是在訓練集上出現了一點錯誤,接下來我們重新去構建一個二輸出的KAN模型

KAN輸出維度=2

dataset = {}
dataset['train_input'] = torch.from_numpy(train_input)
dataset['test_input'] = torch.from_numpy(test_input)
dataset['train_label'] = torch.from_numpy(train_label).type(torch.long)
dataset['test_label'] = torch.from_numpy(test_label).type(torch.long)

model = KAN(width=[4,2], grid=3, k=3)
model(dataset['train_input']);
model.plot(beta=100)

這個模型相對于第一個模型只去修改了它的輸出維數為二,同樣還是把它看作是一個回歸模型

def train_acc():
return torch.mean((torch.argmax(model(dataset['train_input']), dim=1) == dataset['train_label']).float())

def test_acc():
return torch.mean((torch.argmax(model(dataset['test_input']), dim=1) == dataset['test_label']).float())

results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc), loss_fn=torch.nn.CrossEntropyLoss())
model.plot()

同樣是對激活函數進行學習,并可視化

lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
model.auto_symbolic(lib=lib)
formula1, formula2 = model.symbolic_formula()[0]
formula1
formula2

這是一個輸出維數為二的KAN模型相應的它的輸出都有與它一一對應的數學公式來進行解釋

def acc(formula1, formula2, X, y):
batch = X.shape[0]
correct = 0
for i in range(batch):
logit1 = np.array(formula1.subs('x_1', X[i,0]).subs('x_2', X[i,1]).subs('x_3', X[i,2]).subs('x_4', X[i,3])).astype(np.float64)
logit2 = np.array(formula2.subs('x_1', X[i,0]).subs('x_2', X[i,1]).subs('x_3', X[i,2]).subs('x_4', X[i,3])).astype(np.float64)
correct += (logit2 > logit1) == y[i]
return correct/batch

print('train acc of the formula:', acc(formula1, formula2, dataset['train_input'], dataset['train_label']))
print('test acc of the formula:', acc(formula1, formula2, dataset['test_input'], dataset['test_label']))

相應的計算這個KAN模型的準確率,可以發現這個輸出維數為二的KAN比輸出維度為一的KAN要好,這個KAN模型在這個數據集上百分比預測正確,這里利用的是預測結果(即 logit2 > logit1 的布爾值)與真實標簽 y[i] 相等,則返回 True(1),否則返回 False(0),來進行準確率計算,到這里就完成了這個分類模型的構建,讀者也可以嘗試對所有數據集進行三分類KAN構建,下面是作者對完整鳶尾花數據進行構建的KAN模型可視化

本文章轉載微信公眾號@Python機器學習AI

上一篇:

基于熵權法的TOPSIS模型

下一篇:

SOFTS模型的單特征時間序列預測實現

我們有何不同?

API服務商零注冊

多API并行試用

數據驅動選型,提升決策效率

查看全部API→
??

熱門場景實測,選對API

#AI文本生成大模型API

對比大模型API的內容創意新穎性、情感共鳴力、商業轉化潛力

25個渠道
一鍵對比試用API 限時免費

#AI深度推理大模型API

對比大模型API的邏輯推理準確性、分析深度、可視化建議合理性

10個渠道
一鍵對比試用API 限時免費