
實時航班追蹤背后的技術:在線飛機追蹤器的工作原理
library(dplyr) #數據處理使用
library(data.table) #數據讀取使用
library(randomForest) #RF模型使用
library(caret) # 調參和計算模型評價參數使用
library(pROC) #繪圖使用
library(ggplot2) #繪圖使用
library(ggpubr) #繪圖使用
library(ggprism) #繪圖使用
# 讀取數據
data <- fread("./RF_data.txt",data.table = F) # 替換為你的數據文件名或路徑
數據長這個樣子,一共35727行,214列。每一行代表一個樣本,第一列是樣本標簽malignant
或normal
,后面213列是213個特征。我們想根據213個特征,使用RF訓練出一個能夠對樣本進行精準分類的模型。
構建RF模型
# 分割數據為訓練集和測試集
set.seed(123) # 設置隨機種子,保證結果可復現
split <- sample.split(data$type, SplitRatio = 0.8) # 將數據按照指定比例分割
train_data <- subset(data, split == TRUE) # 訓練集
test_data <- subset(data, split == FALSE) # 測試集
# 定義訓練集特征和目標變量
X_train <- train_data[, -1]
y_train <- as.factor(train_data[, 1])
# 創建隨機森林分類模型
model <- randomForest(x = X_train, y = y_train, ntree = 100)
# 輸出默認參數下的模型性能
print(model)
這是一個沒有經過調參的模型結果,盡管看起來模型已經很不錯了,但我們還是繼續進行調參,看一下模型效果能夠上升多少。
有至少兩個參數需要進行測序,分別是mtry
和ntree
。
由于caret包只提供了mtry
參數的調節,關于ntree
參數的調節我們這里手動進行。
mtry
參數調節
# 進行參數調優
# 創建訓練控制對象
ctrl <- trainControl(method = "cv", number = 5) #使用五折交叉驗證,也可以選擇10折交叉驗證。
# 定義參數網格
grid <- expand.grid(mtry = c(2, 4, 6)) # 每棵樹中用于分裂的特征數量,這里只是隨便給的測試,主要為了介紹如何調參,并非最優選擇。
# 使用caret包進行調參
rf_model <- train(x = X_train, y = y_train,
method = "rf",
trControl = ctrl,
tuneGrid = grid)
# 輸出最佳模型和參數
print(rf_model)
使用準確性來選擇最佳模型。該模型最終mtry
值為mtry = 6。
ntree
參數調節
# 調整Caret沒有提供的參數
# 如果我們想調整的參數Caret沒有提供,可以用下面的方式自己手動調參。
# 用剛剛調參的最佳mtry值固定mtry
grid <- expand.grid(mtry = c(6)) # 每棵樹中用于分裂的特征數量
# 定義模型列表,存儲每一個模型評估結果
modellist <- list()
# 調整的參數是決策樹的數量
for (ntree in c(100,200, 300)) {
set.seed(123)
fit <- train(x = X_train, y = y_train, method="rf",
metric="Accuracy", tuneGrid=grid,
trControl=ctrl, ntree=ntree)
key <- toString(ntree)
modellist[[key]] <- fit
}
# compare results
results <- resamples(modellist)
# 輸出最佳模型和參數
summary(results)
從準確性可以看出,ntree = 200是最佳的。這樣我們就完成了調參,最佳的參數組合是mtry = 6,ntree = 200。
使用最佳參數訓練模型
# 使用最佳參數訓練最終模型
final_model <- randomForest(x = X_train, y = y_train,mtry = 6,ntree = 200)
# 輸出最終模型
print(final_model)
從結果可以看出,經過調參的模型比初始模型好了一點點。
這里使用caret包包中的函數來輸出模型的評價指標,想手動計算可以參考邏輯回歸(LR)的推文。
# 在測試集上進行預測
X_test <- test_data[, -1]
y_test <- as.factor(test_data[, 1])
test_predictions <- predict(final_model, newdata = test_data)
# 計算模型指標
confusion_matrix <- confusionMatrix(test_predictions, y_test)
accuracy <- confusion_matrix$overall["Accuracy"]
precision <- confusion_matrix$byClass["Pos Pred Value"]
recall <- confusion_matrix$byClass["Sensitivity"]
f1_score <- confusion_matrix$byClass["F1"]
# 輸出模型指標
print(confusion_matrix)
print(paste("Accuracy:", accuracy))
print(paste("Precision:", precision))
print(paste("Recall:", recall))
print(paste("F1 Score:", f1_score))
從測試集中看,模型表現的也不錯。下面我們來繪制一下混淆矩陣,ROC曲線。
混淆矩陣提供了對分類模型性能的全面評估。它展示了實際類別和預測類別之間的對應關系,可以清晰地看到模型的預測結果中真正例、真反例、假正例和假反例的數量或比例。繪制混淆矩陣可以將復雜的分類結果以直觀的方式展示出來,使得結果更易于理解和解釋。
# 繪制混淆矩陣熱圖
# 將混淆矩陣轉換為數據框
confusion_matrix_df <- as.data.frame.matrix(confusion_matrix$table)
colnames(confusion_matrix_df) <- c("cluster1","cluster2")
rownames(confusion_matrix_df) <- c("cluster1","cluster2")
draw_data <- round(confusion_matrix_df / rowSums(confusion_matrix_df),2)
draw_data$real <- rownames(draw_data)
draw_data <- melt(draw_data)
ggplot(draw_data, aes(real,variable, fill = value)) +
geom_tile() +
geom_text(aes(label = scales::percent(value))) +
scale_fill_gradient(low = "#F0F0F0", high = "#3575b5") +
labs(x = "True", y = "Guess", title = "Confusion matrix") +
theme_prism(border = T)+
theme(panel.border = element_blank(),
axis.ticks.y = element_blank(),
axis.ticks.x = element_blank(),
legend.position="none")
ROC(Receiver Operating Characteristic)曲線和AUC(Area Under the Curve)是評估二分類模型性能常用的指標。
# 繪制ROC曲線需要將預測結果以概率的形式輸出
test_predictions <- predict(final_model, newdata = test_data,type = "prob")
# 計算ROC曲線的參數
roc_obj <- roc(response = y_test, predictor = test_predictions[, 2])
roc_auc <- auc(roc_obj)
# 將ROC對象轉換為數據框
roc_data <- data.frame(1 - roc_obj$specificities, roc_obj$sensitivities)
# 繪制ROC曲線
ggplot(roc_data, aes(x = 1 - roc_obj$specificities, y = roc_obj$sensitivities)) +
geom_line(color = "#0073C2FF", size = 1.5) +
geom_segment(aes(x = 0, y = 0, xend = 1, yend = 1), linetype = "dashed", color = "gray") +
geom_text(aes(x = 0.8, y = 0.2, label = paste("AUC =", round(roc_auc, 2))), size = 4, color = "black") +
coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +
theme_pubr() +
labs(x = "1 - Specificity", y = "Sensitivity") +
ggtitle("ROC Curve") +
theme(plot.title = element_text(size = 14, face = "bold"))+
theme_prism(border = T)
本文章轉載微信公眾號@Bio小菜鳥