保姆级教程UnslothQwen2-7B手把手教你训练医学推理模型1. 环境准备与快速部署1.1 安装Unsloth环境首先我们需要搭建Unsloth的运行环境。Unsloth是一个开源的LLM微调和强化学习框架能够显著提升训练速度并降低显存消耗。以下是安装步骤# 创建conda环境 conda create -n unsloth_env python3.10 -y conda activate unsloth_env # 安装基础依赖 pip install torch torchvision torchaudio pip install unsloth[colab-new] githttps://github.com/unslothai/unsloth.git1.2 验证安装安装完成后可以通过以下命令验证Unsloth是否安装成功python -m unsloth如果看到类似Unsloth is ready to use!的输出说明安装成功。1.3 安装其他必要依赖pip install transformers datasets trl accelerate bitsandbytes2. 数据集准备与处理2.1 下载医学推理数据集我们将使用medical-o1-reasoning-SFT数据集这是一个包含医学问题和详细推理步骤的高质量数据集from datasets import load_dataset dataset load_dataset( json, data_filesmedical_o1_sft.jsonl, # 替换为你的数据集路径 splittrain, trust_remote_codeTrue, )2.2 数据格式说明数据集包含三个关键字段Question: 医学问题来自医学考试等权威来源Complex_CoT: GPT-4生成的推理步骤Chain-of-ThoughtResponse: 最终答案或建议2.3 数据预处理我们需要将原始数据格式化为适合训练的提示模板train_prompt_style 以下是描述任务的指令以及提供更多上下文的输入。 请写出恰当完成该请求的回答。 在回答之前请仔细思考问题并创建一个逐步的思维链以确保回答合乎逻辑且准确。 ### Instruction: 你是一位在临床推理、诊断和治疗计划方面具有专业知识的医学专家。 请回答以下医学问题。 ### Question: {} ### Response: think {} /think {} def formatting_prompts_func(examples): inputs examples[Question] cots examples[Complex_CoT] outputs examples[Response] texts [] for input, cot, output in zip(inputs, cots, outputs): text train_prompt_style.format(input, cot, output) tokenizer.eos_token texts.append(text) return {text: texts} dataset dataset.map(formatting_prompts_func, batchedTrue)3. 模型加载与配置3.1 加载Qwen2-7B基础模型使用Unsloth优化过的加载方法可以显著减少显存占用from unsloth import FastLanguageModel model, tokenizer FastLanguageModel.from_pretrained( model_nameQwen/Qwen2-7B, # 或本地模型路径 max_seq_length2048, dtypeNone, # 自动选择 load_in_4bitTrue, # 4bit量化节省显存 )3.2 初始推理测试在微调前我们先测试模型的原始表现question 一位61岁的女性长期存在咳嗽或打喷嚏等活动时不自主尿失禁的病史... FastLanguageModel.for_inference(model) inputs tokenizer([prompt_style.format(question, )], return_tensorspt).to(cuda) outputs model.generate( input_idsinputs.input_ids, attention_maskinputs.attention_mask, max_new_tokens1200, use_cacheTrue, ) print(tokenizer.batch_decode(outputs)[0])4. LoRA微调配置4.1 准备LoRA参数Unsloth内置了高效的LoRA实现只需几行代码即可配置FastLanguageModel.for_training(model) model FastLanguageModel.get_peft_model( model, r16, # LoRA秩 target_modules[ q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj, ], lora_alpha16, lora_dropout0, biasnone, use_gradient_checkpointingunsloth, random_state3407, )4.2 训练参数配置from trl import SFTTrainer from transformers import TrainingArguments trainer SFTTrainer( modelmodel, tokenizertokenizer, train_datasetdataset, dataset_text_fieldtext, max_seq_length2048, argsTrainingArguments( per_device_train_batch_size2, gradient_accumulation_steps4, warmup_steps5, learning_rate2e-4, lr_scheduler_typelinear, max_steps60, fp16not torch.cuda.is_bf16_supported(), bf16torch.cuda.is_bf16_supported(), logging_steps10, optimadamw_8bit, weight_decay0.01, seed3407, output_diroutputs, ), )5. 开始训练与模型保存5.1 启动训练trainer.train()5.2 保存微调后的模型训练完成后保存LoRA适配器和合并后的模型# 保存LoRA适配器 model.save_pretrained(medical-cot-lora) # 合并并保存完整模型 merged_model FastLanguageModel.merge_and_unload(model) merged_model.save_pretrained(Medical-COT-Qwen-7B)6. 模型测试与部署6.1 本地测试加载合并后的模型进行测试model, tokenizer FastLanguageModel.from_pretrained( model_nameMedical-COT-Qwen-7B, max_seq_length2048, dtypeNone, load_in_4bitTrue, ) FastLanguageModel.for_inference(model) # 测试问题 question 糖尿病患者出现足部溃疡应该如何治疗 inputs tokenizer([prompt_style.format(question, )], return_tensorspt).to(cuda) outputs model.generate( input_idsinputs.input_ids, attention_maskinputs.attention_mask, max_new_tokens1200, use_cacheTrue, ) print(tokenizer.batch_decode(outputs)[0])6.2 使用Streamlit创建Web界面创建一个简单的Web界面来交互式测试模型import streamlit as st from unsloth import FastLanguageModel # 加载模型 st.cache_resource def load_model(): model, tokenizer FastLanguageModel.from_pretrained( model_nameMedical-COT-Qwen-7B, max_seq_length2048, dtypeNone, load_in_4bitTrue, ) FastLanguageModel.for_inference(model) return model, tokenizer model, tokenizer load_model() # 创建界面 st.title(医学问答系统) question st.text_input(请输入医学问题:) if question: inputs tokenizer([prompt_style.format(question, )], return_tensorspt).to(cuda) outputs model.generate( input_idsinputs.input_ids, attention_maskinputs.attention_mask, max_new_tokens1200, use_cacheTrue, ) response tokenizer.batch_decode(outputs)[0] st.write(response.split(### Response:)[1])7. 总结与建议7.1 关键收获通过本教程我们完成了使用Unsloth高效加载和微调Qwen2-7B模型在医学推理数据集上进行了指令微调实现了包含推理步骤的医学问答系统部署了可交互的Web界面7.2 优化建议数据质量确保医学数据的准确性和专业性提示工程优化提示模板以获得更结构化的输出评估指标建立专业的医学评估标准持续训练可以考虑RLHF进一步优化模型表现7.3 应用场景训练好的模型可用于医学教育辅助临床决策支持患者问答系统医学文献理解获取更多AI镜像想探索更多AI镜像和应用场景访问 CSDN星图镜像广场提供丰富的预置镜像覆盖大模型推理、图像生成、视频生成、模型微调等多个领域支持一键部署。