!pip install -qqq torch==2.0.1 --progress-bar off
!pip install -qqq transformers==4.32.1 --progress-bar off
!pip install -qqq datasets==2.14.4 --progress-bar off
!pip install -qqq peft==0.5.0 --progress-bar off
!pip install -qqq bitsandbytes==0.41.1 --progress-bar off
!pip install -qqq trl==0.7.1 --progress-bar off

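The bitsandbytes library will help us load the model in 4-bit precision. The peft library provides the tools for applying LoRA. The trl library provides a trainer class (SFTTrainer) that we'll use to fine-tune the model.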

Next, let's add the required imports:

import json
import re
from pprint import pprint

import pandas as pd
import torch
from datasets import Dataset, load_dataset
from huggingface_hub import notebook_login
from peft import LoraConfig, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "meta-llama/Llama-2-7b-hf"

We'll use the 7b version of Meta AI's Llama 2. It's a base model (not instruction-tuned), because we don't plan to use it in a chat setting.

Data Preprocessing

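The dataset we'll use consists of conversations between customers and support agents on Twitter. The data is provided by Salesforce and is publicly available on the HuggingFace Datasets Hub as the TweetSumm subset of the DialogStudio collection. It contains 1,099 conversations: 879 for training, 110 for validation, and 110 for testing. Let's load it: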

dataset = load_dataset("Salesforce/dialogstudio", "TweetSumm")
dataset
DatasetDict({
    train: Dataset({
        features: ['original dialog id', 'new dialog id', 'dialog index', 'original dialog info', 'log', 'prompt'],
        num_rows: 879
    })
    validation: Dataset({
        features: ['original dialog id', 'new dialog id', 'dialog index', 'original dialog info', 'log', 'prompt'],
        num_rows: 110
    })
    test: Dataset({
        features: ['original dialog id', 'new dialog id', 'dialog index', 'original dialog info', 'log', 'prompt'],
        num_rows: 110
    })
})

Let's look at the dataset preview on the HuggingFace Datasets Hub:

Figure: dataset preview on the HuggingFace Datasets Hub

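We're mainly interested in two fields: original dialog info, which holds the reference summaries, and log, which holds the conversation turns.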

Next, let's write a function that extracts this information from a data point:

def generate_text(data_point):
    # "original dialog info" is a JSON string; take the first abstractive summary
    summaries = json.loads(data_point["original dialog info"])["summaries"][
        "abstractive_summaries"
    ]
    summary = summaries[0]
    summary = " ".join(summary)  # the summary is stored as a list of sentences

    conversation_text = create_conversation_text(data_point)
    return {
        "conversation": conversation_text,
        "summary": summary,
        "text": generate_training_prompt(conversation_text, summary),
    }

The summary is extracted from the nested JSON in the data point's original dialog info field. Here's an example summary:

Customer enquired about his Iphone and Apple watch which is not showing his any steps/activity and health activities. Agent is asking to move to DM and look into it.
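The nesting that generate_text relies on is easier to see on a concrete value. Below is a minimal sketch that assumes the structure implied by generate_text: original dialog info is a JSON string, abstractive_summaries is a list of summaries, and each summary is a list of sentences. The JSON payload is hypothetical and trimmed, and splitting the example summary into two sentences is an assumption made only for illustration.

```python
import json

# Hypothetical, trimmed "original dialog info" value (the real field has more keys)
original_dialog_info = json.dumps({
    "summaries": {
        "abstractive_summaries": [
            [
                "Customer enquired about his Iphone and Apple watch which is not showing his any steps/activity and health activities.",
                "Agent is asking to move to DM and look into it.",
            ],
        ]
    }
})


def extract_first_summary(dialog_info_json: str) -> str:
    # Same access path as generate_text: parse JSON, take the first summary,
    # then join its sentences with single spaces
    summaries = json.loads(dialog_info_json)["summaries"]["abstractive_summaries"]
    return " ".join(summaries[0])


print(extract_first_summary(original_dialog_info))
```

If a summary were stored as a plain string rather than a list of sentences, " ".join would put a space between every character. That is why the list-of-sentences layout is the one the article's code implies.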

Let's look at the create_conversation_text function:

def create_conversation_text(data_point):
    text = ""
    for item in data_point["log"]:
        user = clean_text(item["user utterance"])
        text += f"user: {user.strip()}\n"

        agent = clean_text(item["system response"])
        text += f"agent: {agent.strip()}\n"

    return text


def clean_text(text):
    text = re.sub(r"http\S+", "", text)  # remove URLs
    text = re.sub(r"@[^\s]+", "", text)  # remove @mentions
    text = re.sub(r"\s+", " ", text)  # collapse runs of whitespace
    return re.sub(r"\^[^ ]+", "", text)  # remove tokens starting with ^

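This function joins the turns stored in the data point's log field into a single text and prefixes each turn with user: or agent:. It also cleans the text with clean_text, which removes URLs, @mentions, and tokens starting with ^, and collapses repeated whitespace. Here's an example conversation: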

user: So neither my iPhone nor my Apple Watch are recording my steps/activity, and Health doesn't recognise either source anymore for some reason. Any ideas? please read the above.
agent: Let's investigate this together. To start, can you tell us the software versions your iPhone and Apple Watch are running currently?
user: My iPhone is on 11.1.2, and my watch is on 4.1.
agent: Thank you. Have you tried restarting both devices since this started happening?
user: I've restarted both, also un-paired then re-paired the watch.
agent: Got it. When did you first notice that the two devices were not talking to each other. Do the two devices communicate through other apps such as Messages?
user: Yes, everything seems fine, it's just Health and activity.
agent: Let's move to DM and look into this a bit more. When reaching out in DM, let us know when this first started happening please. For example, did it start after an update or after installing a certain app?
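To make the cleaning rules concrete, here is a small standalone sketch that reuses the same regular expressions as clean_text and create_conversation_text. It runs them on a hypothetical data point shaped like a TweetSumm log entry; the tweet text and handles are made up. It also shows why the .strip() calls matter. The ^ rule runs after whitespace has been collapsed, so it can leave a trailing space behind.

```python
import re


def clean_text(text: str) -> str:
    # Same regex pipeline as the article's clean_text
    text = re.sub(r"http\S+", "", text)  # drop URLs
    text = re.sub(r"@[^\s]+", "", text)  # drop @mentions
    text = re.sub(r"\s+", " ", text)  # collapse runs of whitespace
    return re.sub(r"\^[^ ]+", "", text)  # drop tokens starting with ^


def create_conversation_text(data_point: dict) -> str:
    # One "user:" and one "agent:" line per log turn
    text = ""
    for item in data_point["log"]:
        user = clean_text(item["user utterance"])
        text += f"user: {user.strip()}\n"
        agent = clean_text(item["system response"])
        text += f"agent: {agent.strip()}\n"
    return text


# Hypothetical data point; only the "log" field is used here
sample = {
    "log": [
        {
            "user utterance": "@AppleSupport my watch stopped syncing https://t.co/xyz",
            "system response": "@user123 Sorry to hear that!   Which watchOS version? ^JK",
        }
    ]
}

print(repr(clean_text("see https://x.co/a now ^AB")))  # 'see now '
print(create_conversation_text(sample), end="")
# user: my watch stopped syncing
# agent: Sorry to hear that! Which watchOS version?
```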

The last piece is the prompt-generation function, which produces the text we'll train on:

DEFAULT_SYSTEM_PROMPT = """
Below is a conversation between a human and an AI agent. Write a summary of the conversation.
""".strip()

def generate_training_prompt(
    conversation: str, summary: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT
) -> str:
    return f"""### Instruction: {system_prompt}

### Input:
{conversation.strip()}

### Response:
{summary}
""".strip()

We use an Alpaca-style prompt format. Here's the prompt for our example:

### Instruction: Below is a conversation between a human and an AI agent. Write a summary of the conversation.

### Input:
user: So neither my iPhone nor my Apple Watch are recording my steps/activity, and Health doesn't recognise either source anymore for some reason. Any ideas? please read the above.
agent: Let's investigate this together. To start, can you tell us the software versions your iPhone and Apple Watch are running currently?
user: My iPhone is on 11.1.2, and my watch is on 4.1.
agent: Thank you. Have you tried restarting both devices since this started happening?
user: I've restarted both, also un-paired then re-paired the watch.
agent: Got it. When did you first notice that the two devices were not talking to each other. Do the two devices communicate through other apps such as Messages?
user: Yes, everything seems fine, it's just Health and activity.
agent: Let's move to DM and look into this a bit more. When reaching out in DM, let us know when this first started happening please. For example, did it start after an update or after installing a certain app?

### Response:
Customer enquired about his Iphone and Apple watch which is not showing his any steps/activity and health activities. Agent is asking to move to DM and look into it.

We can now use a helper function to process the whole dataset:

def process_dataset(data: Dataset):
    return (
        data.shuffle(seed=42)
        .map(generate_text)
        .remove_columns(
            [
                "original dialog id",
                "new dialog id",
                "dialog index",
                "original dialog info",
                "log",
                "prompt",
            ]
        )
    )

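Here we use the datasets library to shuffle the data with a fixed seed and apply generate_text to every data point. We then drop the original columns we no longer need. Next, we apply this processing to the training and validation splits: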

dataset["train"] = process_dataset(dataset["train"])
dataset["validation"] = process_dataset(dataset["validation"])
dataset
DatasetDict({
    train: Dataset({
        features: ['conversation', 'summary', 'text'],
        num_rows: 879
    })
    validation: Dataset({
        features: ['conversation', 'summary', 'text'],
        num_rows: 110
    })
    test: Dataset({
        features: ['original dialog id', 'new dialog id', 'dialog index', 'original dialog info', 'log', 'prompt'],
        num_rows: 110
    })
})

We'll deal with the test split later.

Model

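We'll use the base 7b version of Llama 2 and load it in 4-bit precision with the bitsandbytes library. First, log in to the HuggingFace Hub. The Llama 2 weights are gated, so your account needs to have been granted access: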

notebook_login()

Next, let's write a helper function that loads the model and the tokenizer:

def create_model_and_tokenizer():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        use_safetensors=True,
        quantization_config=bnb_config,
        trust_remote_code=True,
        device_map="auto",
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    return model, tokenizer

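For 4-bit quantization we use the 4-bit NormalFloat (NF4) data type, with computations carried out in float16. We also set use_safetensors so that the weights are loaded from the safetensors format. Next, we download the model and the tokenizer: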

model, tokenizer = create_model_and_tokenizer()
model.config.use_cache = False

The transformers library integrates well with a number of quantization libraries, so we can easily inspect the model's quantization config:

model.config.quantization_config.to_dict()
{
    'quant_method': <QuantizationMethod.BITS_AND_BYTES: 'bitsandbytes'>,
    'load_in_8bit': False,
    'load_in_4bit': True,
    'llm_int8_threshold': 6.0,
    'llm_int8_skip_modules': None,
    'llm_int8_enable_fp32_cpu_offload': False,
    'llm_int8_has_fp16_weight': False,
    'bnb_4bit_quant_type': 'nf4',
    'bnb_4bit_use_double_quant': False,
    'bnb_4bit_compute_dtype': 'float16'
}
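As a rough sanity check on why 4-bit loading matters, the sketch below estimates weights-only memory for a 7B model. Several assumptions here do not come from the article. The model is taken to have about 6.74e9 parameters, all of them quantized. NF4 is assumed to be applied block-wise, with one 32-bit absmax constant per 64 weights and no double quantization, which matches 'bnb_4bit_use_double_quant': False above. In practice some layers (such as the embeddings and the output head) are not quantized, so real usage is somewhat higher. Activations and any optimizer state also come on top.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Weights-only memory in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NUM_PARAMS = 6.74e9  # assumed approximate parameter count of Llama-2-7b
BLOCK_SIZE = 64  # assumed number of weights sharing one quantization constant
ABSMAX_BITS = 32  # one fp32 scale per block (no double quantization)

fp16_gb = weight_memory_gb(NUM_PARAMS, 16)
nf4_gb = weight_memory_gb(NUM_PARAMS, 4)
nf4_with_scales_gb = weight_memory_gb(NUM_PARAMS, 4 + ABSMAX_BITS / BLOCK_SIZE)

print(f"fp16 weights:           {fp16_gb:.2f} GB")
print(f"nf4 weights:            {nf4_gb:.2f} GB")
print(f"nf4 + per-block scales: {nf4_with_scales_gb:.2f} GB")
```

The per-block constants add 32 / 64 = 0.5 bits per weight, which is the overhead that double quantization is designed to shrink.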

The last component is the QLoRA configuration:

lora_r = 16
lora_alpha = 64
lora_dropout = 0.1
lora_target_modules = [
    "q_proj",
    "up_proj",
    "o_proj",
    "k_proj",
    "down_proj",
    "gate_proj",
    "v_proj",
]

peft_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    target_modules=lora_target_modules,
    bias="none",
    task_type="CAUSAL_LM",
)

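We set the rank of the update matrices (r = 16) and the dropout rate (lora_dropout = 0.1). The LoRA update is scaled by lora_alpha / r, which is 64 / 16 = 4 here. The target modules cover every linear projection in the attention blocks (q, k, v, o) and in the MLP (gate, up, down).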
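To get a feel for how small the trainable part is, here is a back-of-the-envelope count of LoRA parameters for the configuration above. Each adapted linear layer with input size d_in and output size d_out gets two matrices, A (r × d_in) and B (d_out × r). That adds r · (d_in + d_out) parameters per layer. The Llama-2-7b dimensions used here are assumptions taken from the model's published config, not from the article: hidden size 4096, MLP size 11008, 32 layers, and no grouped-query attention.

```python
def lora_param_count(r: int, hidden: int = 4096, intermediate: int = 11008,
                     num_layers: int = 32, target_modules=None) -> int:
    # (d_in, d_out) of each linear layer in one Llama decoder block (assumed dims)
    shapes = {
        "q_proj": (hidden, hidden),
        "k_proj": (hidden, hidden),
        "v_proj": (hidden, hidden),
        "o_proj": (hidden, hidden),
        "gate_proj": (hidden, intermediate),
        "up_proj": (hidden, intermediate),
        "down_proj": (intermediate, hidden),
    }
    modules = target_modules or list(shapes)
    # A is r x d_in and B is d_out x r, so each module adds r * (d_in + d_out)
    per_layer = sum(r * (shapes[m][0] + shapes[m][1]) for m in modules)
    return per_layer * num_layers


lora_r, lora_alpha = 16, 64
target = ["q_proj", "up_proj", "o_proj", "k_proj", "down_proj", "gate_proj", "v_proj"]

trainable = lora_param_count(lora_r, target_modules=target)
scaling = lora_alpha / lora_r

print(f"trainable LoRA params: {trainable:,}")  # 39,976,960
print(f"update scaling alpha/r: {scaling}")  # 4.0
```

Under these assumptions, about 40M parameters are trained. That is well under 1% of the roughly 6.7B frozen base weights.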

Training

We'll use TensorBoard to monitor training. Let's start it:

OUTPUT_DIR = "experiments"

%load_ext tensorboard
%tensorboard --logdir experiments/runs

Next, let's set up the training arguments:

training_arguments = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    optim="paged_adamw_32bit",
    logging_steps=1,
    learning_rate=1e-4,
    fp16=True,
    max_grad_norm=0.3,
    num_train_epochs=2,
    evaluation_strategy="steps",
    eval_steps=0.2,
    warmup_ratio=0.05,
    save_strategy="epoch",
    group_by_length=True,
    output_dir=OUTPUT_DIR,
    report_to="tensorboard",
    save_safetensors=True,
    lr_scheduler_type="cosine",
    seed=42,
)

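Most of the settings are fairly self-explanatory. We're using:

- a per-device batch size of 4 with 4 gradient-accumulation steps, for an effective batch size of 16
- the paged 32-bit AdamW optimizer
- a learning rate of 1e-4 with a cosine schedule and 5% warmup
- fp16 mixed precision
- gradient clipping at 0.3
- 2 epochs
- an evaluation every 20% of the training steps
- a checkpoint at the end of each epoch
- length-grouped batching (group_by_length) to reduce padding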

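The trainer class we'll use, SFTTrainer, comes from the trl library and wraps the transformers Trainer. On top of the standard training setup, we pass it two extra options, peft_config and dataset_text_field. dataset_text_field names the dataset field that contains the training prompt. max_seq_length caps the tokenized length at 4096 tokens, which is Llama 2's context size.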

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=4096,
    tokenizer=tokenizer,
    args=training_arguments,
)
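One detail is worth keeping in mind before training. create_model_and_tokenizer sets tokenizer.pad_token = tokenizer.eos_token. By default, and without packing, SFTTrainer uses a causal-LM collator that copies input_ids into labels. It then replaces every position equal to pad_token_id with -100, the index the loss ignores. Because pad and EOS share an id, any EOS token gets masked too. The Llama tokenizer also does not append EOS by default, so the model is never trained to stop after the summary. That is a plausible explanation for the repeated ### Input: / ### Response: blocks in the evaluation below. The sketch mimics this masking rule in plain Python. Llama 2's BOS (1) and EOS (2) ids are real; the other token ids are made up.

```python
IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss

BOS_ID, EOS_ID = 1, 2  # Llama 2 special-token ids
PAD_ID = EOS_ID  # the article sets tokenizer.pad_token = tokenizer.eos_token


def causal_lm_labels(input_ids, pad_token_id):
    # Copy the inputs and hide every position whose id equals the pad id
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]


# Hypothetical sequence: BOS, three text tokens, an explicit EOS, then right padding
input_ids = [BOS_ID, 450, 1234, 5678, EOS_ID, PAD_ID, PAD_ID]
labels = causal_lm_labels(input_ids, PAD_ID)
print(labels)  # [1, 450, 1234, 5678, -100, -100, -100]
```

The EOS position is masked along with the padding, so "stop here" is never a training target.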

Let's start training:

trainer.train()
Step  Training Loss  Validation Loss
22    1.906400       1.921726
44    1.823500       1.881039
66    1.677000       1.861916
88    1.774600       1.853609
110   1.646800       1.852111
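The step numbers in this log follow from the training arguments. The sketch below reproduces the arithmetic under a few assumptions: a single GPU, no packing (so 879 training examples), one optimizer step per gradient_accumulation_steps micro-batches, and a fractional eval_steps or warmup_ratio interpreted as a fraction of the total number of steps (rounded up).

```python
import math


def training_schedule(num_examples, per_device_batch, grad_accum, epochs,
                      eval_ratio, warmup_ratio, num_devices=1):
    # Micro-batches per epoch (the last partial batch is kept)
    batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_devices))
    # One optimizer step per `grad_accum` micro-batches
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = steps_per_epoch * epochs
    # Fractional values are taken relative to the total number of steps
    eval_every = math.ceil(total_steps * eval_ratio)
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return {
        "effective_batch_size": per_device_batch * grad_accum * num_devices,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "eval_every": eval_every,
        "warmup_steps": warmup_steps,
        "eval_points": list(range(eval_every, total_steps + 1, eval_every)),
    }


schedule = training_schedule(879, per_device_batch=4, grad_accum=4, epochs=2,
                             eval_ratio=0.2, warmup_ratio=0.05)
print(schedule)
```

The computed evaluation points, 22, 44, 66, 88 and 110, match the rows of the log above.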

Let's look at the training metrics in TensorBoard:

Figure: train (left) and eval (right) loss

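Both losses go down. Training loss falls from about 1.91 to 1.65, and validation loss falls from about 1.92 to 1.85, flattening out during the second epoch. Now let's save the model: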

trainer.save_model()

This saves only the QLoRA adapter weights and the model config. You'll still need to load the original model and tokenizer.

Merging the QLoRA adapter with Llama 2 (optional)

You can merge the QLoRA adapter into the original model to get a single, standalone model for inference. Here's how:

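from peft import AutoPeftModelForCausalLM

# Reload the base model together with the saved adapter; fp16 avoids loading the base weights in fp32
trained_model = AutoPeftModelForCausalLM.from_pretrained(
    OUTPUT_DIR,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)

# Merge the adapter into the reloaded model (not the 4-bit `model` used for training)
merged_model = trained_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True)
tokenizer.save_pretrained("merged_model")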

You can now load your model and tokenizer from the merged_model directory.
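Conceptually, merging folds the low-rank update into the base weight, W' = W + (lora_alpha / r) · B·A. The merged layer then computes the same output as the base layer plus the adapter, without any extra matrices at inference time. The toy sketch below checks that identity in plain Python with tiny, arbitrary matrices. It illustrates the math only and is not PEFT's implementation.

```python
def matmul(X, Y):
    # Nested-list matrix product: (n x k) @ (k x m) -> (n x m)
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def merge_lora(W, A, B, alpha, r):
    # W: d_out x d_in, B: d_out x r, A: r x d_in
    scale = alpha / r
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(W, delta)]


# Toy layer with d_in = d_out = 2 and rank r = 1 (values are arbitrary)
W = [[1.0, 0.0], [0.0, 1.0]]
B = [[1.0], [2.0]]
A = [[0.5, 0.5]]
alpha, r = 2, 1
x = [1.0, 1.0]

merged = merge_lora(W, A, B, alpha, r)

# Unmerged path: base output plus the scaled adapter output
base_out = matvec(W, x)
adapter_out = matvec(B, matvec(A, x))
unmerged = [b + (alpha / r) * a for b, a in zip(base_out, adapter_out)]

print(merged)  # [[2.0, 1.0], [2.0, 3.0]]
print(matvec(merged, x))  # [3.0, 5.0]
print(unmerged)  # [3.0, 5.0]
```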

Evaluation

Let's look at some predictions on the test set. To build the model inputs, we'll use a generate_prompt function:

def generate_prompt(
    conversation: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT
) -> str:
    return f"""### Instruction: {system_prompt}

### Input:
{conversation.strip()}

### Response:
""".strip()

Let's build the examples (summary, conversation, and prompt):

examples = []
for data_point in dataset["test"].select(range(5)):
    summaries = json.loads(data_point["original dialog info"])["summaries"][
        "abstractive_summaries"
    ]
    summary = summaries[0]
    summary = " ".join(summary)
    conversation = create_conversation_text(data_point)
    examples.append(
        {
            "summary": summary,
            "conversation": conversation,
            "prompt": generate_prompt(conversation),
        }
    )
test_df = pd.DataFrame(examples)
test_df
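   summary                                          conversation                                      prompt
0  Customer is complaining that the watchlist is…   user: My watchlist is not updating with new ep…   ### Instruction: Below is a conversation between…
1  Customer is asking about the ACC to link to…     user: hi , my Acc was linked to an old number…    ### Instruction: Below is a conversation between…
2  Customer is complaining about the new updates…   user: the new update ios11 sucks. I can't even…   ### Instruction: Below is a conversation between…
3  Customer complains about the parcel service…     user: Fuck you and your shitty parcel service…    ### Instruction: Below is a conversation between…
4  Customer says he is stuck at Staines…            user: Stuck at Staines waiting for reading test…  ### Instruction: Below is a conversation between…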

Finally, let's add a helper function that summarizes a given prompt:

def summarize(model, text: str):
    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
    inputs_length = len(inputs["input_ids"][0])
    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=256, temperature=0.0001)
    # Decode only the newly generated tokens, not the prompt
    return tokenizer.decode(outputs[0][inputs_length:], skip_special_tokens=True)
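summarize passes temperature=0.0001, which is meant to make decoding effectively greedy. In transformers, the temperature only takes effect when sampling is enabled (do_sample=True, set explicitly or through the model's generation config). Otherwise generate already decodes greedily. The sketch below shows why a tiny temperature amounts to greedy decoding: dividing the logits by T before the softmax pushes all probability onto the largest logit as T approaches 0. The logits are made-up numbers.

```python
import math


def softmax_with_temperature(logits, temperature):
    # Scale logits by 1/T, subtract the max for numerical stability, then normalize
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5]  # hypothetical next-token logits

for t in (1.0, 0.5, 0.0001):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 4) for p in probs])
```

At T = 0.0001 the non-maximal probabilities underflow to exactly 0, so sampling always returns the argmax token.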

Let's load the base model and the fine-tuned model:

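model, tokenizer = create_model_and_tokenizer()
# Note: this injects the LoRA layers into `model` in place; for true base-model outputs,
# call summarize(model, ...) before this line or inside `with trained_model.disable_adapter():`
trained_model = PeftModel.from_pretrained(model, OUTPUT_DIR)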

Let's look at the first example from the test set:

example = test_df.iloc[0]
print(example.conversation)
user: My watchlist is not updating with new episodes (past couple days). Any idea why?
agent: Apologies for the trouble, Norlene! We're looking into this. In the meantime, try navigating to the season / episode manually.
user: Tried logging out/back in, that didn't help
agent: Sorry! ?? We assure you that our team is working hard to investigate, and we hope to have a fix ready soon!
user: Thank you! Some shows updated overnight, but others did not...
agent: We definitely understand, Norlene. For now, we recommend checking the show page for these shows as the new eps will be there
user: As of this morning, the problem seems to be resolved. Watchlist updated overnight with all new episodes. Thank you for your attention to this matter! I love Hulu ??
agent: Awesome! That's what we love to hear. If you happen to need anything else, we'll be here to support! ??

Here's the summary from the dataset:

print(example.summary)

Original summary

Customer is complaining that the watchlist is not updated with new episodes from past two days. Agent informed that the team is working hard to investigate to show new episodes on page.

Let's get a summary from the base Llama 2 model:

summary = summarize(model, example.prompt)
pprint(summary)

Base model summary

('\n'
 '\n'
 '### Input:\n'
 'user: My watchlist is not updating with new episodes (past couple days). Any '
 'idea why?\n'
 "agent: Apologies for the trouble, Norlene! We're looking into this. In the "
 'meantime, try navigating to the season / episode manually.\n'
 "user: Tried logging out/back in, that didn't help\n"
 'agent: Sorry! ?? We assure you that our team is working hard to investigate, '
 'and we hope to have a fix ready soon!\n'
 'user: Thank you! Some shows updated overnight, but others did not...\n'
 'agent: We definitely understand, Norlene. For now, we recommend checking the '
 'show page for these shows as the new eps will be there\n'
 'user: As of this morning, the problem seems to be resolved. Watchlist '
 'updated overnight with all new episodes. Thank you for your attention to '
 'this matter! I love Hulu ??\n'
 "agent: Awesome! That's what we love to hear. If you happen to need anything "
 "else, we'll be here to support! ??\n"
 '\n'
 '### Output:\n'
 '\n'
 '### Input:\n'
 'user: My watchlist')

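That doesn't look good: the base model just repeats the input instead of summarizing it. Let's see what the fine-tuned model outputs: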

summary = summarize(trained_model, example.prompt)
pprint(summary)

Fine-tuned model summary

('\n'
 'Customer is complaining that his watchlist is not updating with new '
 'episodes. Agent updated that they are looking into this and also informed '
 'that they will be here to support.\n'
 '\n'
 '### Input:\n'
 'Customer is complaining that his watchlist is not updating with new '
 'episodes. Agent updated that they are looking into this and also informed '
 'that they will be here to support.\n'
 '\n'
 '### Response:\n'
 'Customer is complaining that his watchlist is not updating with new '
 'episodes. Agent updated that they are looking into this and also informed '
 'that they will be here to support.\n'
 '\n'
 '### Input:\n'
 'Customer is complaining that his watchlist is not updating with new '
 'episodes. Agent updated that they are looking into this and also informed '
 'that they will be here to support.\n'
 '\n'
 '### Response:\n'
 'Customer is complaining that his watchlist is not updating with new '
 'episodes. Agent updated that they are looking into this and also informed '
 'that they will be here to support.\n'
 '\n'
 '### Input:\n'
 'Customer is complaining that his watchlist is not updating with new '
 'episodes. Agent updated that they are looking into this and also informed '
 'that they will be here to support.\n'
 '\n'
 '### Response:\n'
 'Customer is complaining that his watchlist is')

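That's a clear improvement, but the model doesn't stop after the summary and keeps repeating ### Input: / ### Response: blocks. Let's keep only the first line: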

pprint(summary.strip().split("\n")[0])

Fine-tuned model summary (cleaned)

Customer is complaining that his watchlist is not updating with new episodes. Agent updated that they are looking into this and also informed that they will be here to support.
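Keeping only the first line works for the fine-tuned outputs. For an output that starts with a blank line and a section marker, like the base model output above, it would return "### Input:" instead. A slightly more defensive variant cuts the generation at the first prompt-section marker and keeps the first paragraph before it. extract_summary and STOP_MARKERS are hypothetical names, not part of the article. The raw strings are abbreviated versions of the outputs shown above.

```python
STOP_MARKERS = ("### Instruction:", "### Input:", "### Response:", "### Output:")


def extract_summary(generated: str, markers=STOP_MARKERS) -> str:
    text = generated.strip()
    # Cut at the earliest section marker the model started to echo
    cut = min((i for i in (text.find(m) for m in markers) if i != -1), default=len(text))
    # Keep only the first paragraph of whatever precedes the marker
    return text[:cut].strip().split("\n\n")[0].strip()


fine_tuned_raw = (
    "\nCustomer is complaining that his watchlist is not updating with new episodes.\n"
    "\n### Input:\nCustomer is complaining that his watchlist is not updating...\n"
)
base_raw = "\n\n### Input:\nuser: My watchlist is not updating with new episodes...\n"

print(repr(extract_summary(fine_tuned_raw)))
print(repr(extract_summary(base_raw)))  # '' because no summary precedes the echoed input
```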

This looks much better and is a good summary. Let's try the next example:

example = test_df.iloc[1]
print(example.conversation)
user: hi , my Acc was linked to an old number. Now I'm asked to verify my Acc , where a code / call wil be sent to my old number. Any way that I can link my Acc to my current number? Pls help
agent: Hi there, we are here to help. We will have a specialist contact you about changing your phone number. Thank you.
user: Thanks. Hope to get in touch soon
agent: That is no problem. Please let us know if you have any further questions in the meantime.
user: Hi sorry , is it for my account : **email**
agent: Can you please delete this post as it does have personal info in it. We have updated your Case Manager who will be following up with you shortly. Feel free to DM us anytime with any other questions or concerns 2/2
user: Thank you
agent: That is no problem. Please do not hesitate to contact us with any further questions. Thank you.

Original summary

Customer is asking about the ACC to link to the current number. Agent says that they have updated their case manager.

The original summary is very short. Let's see what the base model produces:

Base model summary

('\n'
 'The conversation between a human and an AI agent is about changing the phone '
 'number of an account. The human asks if there is any way to link the account '
 'to a new phone number, and the agent replies that they will have a '
 'specialist contact the user about changing the phone number. The human '
 'thanks the agent and hopes to get in touch soon. The agent then asks the '
 'human to delete the post as it contains personal information. The human '
 'replies that they will delete the post. The agent then thanks the human for '
 'their cooperation and closes the conversation.\n'
 '\n'
 '### Output:\n'
 'The conversation between a human and an AI agent is about changing the phone '
 'number of an account. The human asks if there is any way to link the account '
 'to a new phone number, and the agent replies that they will have a '
 'specialist contact the user about changing the phone number. The human '
 'thanks the agent and hopes to get in touch soon. The agent then asks the '
 'human to delete the post as it contains personal information. The human '
 'replies that they will delete the post. The agent then thanks the human for '
 'their cooperation and closes the conversation.\n'
 '\n'
 '### Output:\n'
 'The conversation between a human and an AI agent is')

That's much better than the first example, but it's still far too long and it repeats itself. Let's see what the fine-tuned model produces:

Fine-tuned model summary (cleaned)

Customer is asking to link his account to his current number. Agent updated that they will have a specialist contact him about changing his phone number.

Perfect: short and to the point. One last example:

example = test_df.iloc[2]
print(example.conversation)
user: the new update ios11 sucks. I can't even use some apps on my phone.
agent: We want your iPhone to work properly, and we are here for you. Which apps are giving you trouble, and which iPhone?
user: 6s. Words with friends Words pro
agent: Do you see app updates in App Store &gt; Updates? Also, are you using iOS 11.0.3?
user: I am using 11.0.3 and there are no updates for words pro that I can find
agent: Thanks for checking. Next, what happens in that app that makes it unusable?
user: It's says it's not compatible.
agent: Thanks for confirming this. Send us a DM and we'll work from there:

Original summary

Customer is complaining about the new updates IOS11 and can't even use some apps on phone. Agent asks to send a DM and work from there URL.

Again, let's look at the base model summary:

Base model summary

('\n'
 '\n'
 '### Input:\n'
 "user: the new update ios11 sucks. I can't even use some apps on my phone.\n"
 'agent: We want your iPhone to work properly, and we are here for you. Which '
 'apps are giving you trouble, and which iPhone?\n'
 'user: 6s. Words with friends Words pro\n'
 'agent: Do you see app updates in App Store &gt; Updates? Also, are you using '
 'iOS 11.0.3?\n'
 'user: I am using 11.0.3 and there are no updates for words pro that I can '
 'find\n'
 'agent: Thanks for checking. Next, what happens in that app that makes it '
 'unusable?\n'
 "user: It's says it's not compatible.\n"
 "agent: Thanks for confirming this. Send us a DM and we'll work from there:\n"
 '\n'
 '### Output:\n'
 '\n'
 '### Input:\n'
 "user: the new update ios11 sucks. I can't even use some apps on my phone.\n"
 'agent: We want your iPhone to work properly, and we are here for you. Which '
 'apps are giving you trouble, and which iPhone?\n'
 'user: 6s. W')

It's basically a copy of the conversation. Let's see what the fine-tuned model gives us:

Fine-tuned model summary (cleaned)

Customer is complaining about the new update ios11 sucks. Agent updated to send a DM and they will work from there.

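I really like this one; it's much better than the original summary. It concisely captures the main point of the conversation: the customer thinks the iOS 11 update sucks.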

Conclusion

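Fine-tuning Llama 2 gave us a way to generate summaries of customer-support conversations. Compared with the base model, the fine-tuned model produces shorter, more focused summaries that get to the point, once the generated text is cut after the first line. I'd say the fine-tuning was successful for our specific use case.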

Original article: https://www.mlexpert.io/blog/fine-tuning-llama-2-on-custom-dataset
