七月论文审稿GPT第4版:通过paper-review数据集微调Mixtral-8x7b

发布于:2024-03-29 ⋅ 阅读:(42) ⋅ 点赞:(0)

模型训练

Mixtral-8x7b地址:魔搭社区

GitHub: hiyouga/LLaMA-Factory: Unify Efficient Fine-tuning of 100+ LLMs (github.com)

环境配置

git clone https://github.com/hiyouga/LLaMA-Factory.git
conda create -n llama_factory python=3.10
conda activate llama_factory
cd /root/path/LLaMA-Factory
pip install -r requirements.txt

有些得单独版本对齐,本人使用的是cuda11.8

pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118
pip install bitsandbytes==0.41.3
# 下载对应版本 https://github.com/Dao-AILab/flash-attention/releases
pip install flash_attn-2.5.2+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

训练代码

python src/train_bash.py \
    --stage sft \
    --do_train True \
    --model_name_or_path /root/weights/Mixtral-8x7B-Instruct-v0.1 \
    --finetuning_type lora \
	--quantization_bit 4 \
    --template mistral \
    --flash_attn True \
    --dataset_dir data \
    --dataset paper_review_data \
    --cutoff_len 12288 \
    --learning_rate 5e-05 \
    --num_train_epochs 3.0 \
    --max_samples 1000000 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --max_grad_norm 0.3 \
    --logging_steps 10 \
	--warmup_steps 0 \
	--lora_rank 128 \
    --save_steps 1000 \
    --lora_dropout 0.05 \
    --lora_target q_proj,o_proj,k_proj,v_proj,down_proj,gate_proj,up_proj \
    --output_dir saves/Mixtral-8x7B-Chat/lora/train_2024-03-23 \
    --fp16 True \
    --plot_loss True

模型推理

部署API接口

这里使用lora执行src/api_demo.py时会出现一个问题:

NotImplementedError: Cannot copy out of meta tensor; no data! · Issue #2940 · hiyouga/LLaMA-Factory (github.com)

解决方案:训练时使用了--quantization_bit 4 和 --flash_attn True,这里也要使用统一的才行。

CUDA_VISIBLE_DEVICES=0 API_PORT=8000 python src/api_demo.py \
    --model_name_or_path /root/weights/Mixtral-8x7B-Instruct-v0.1 \
    --adapter_name_or_path /root/path/saves/Mixtral-8x7B-Chat/lora/train_train_2024-03-23 \
    --template mistral \
    --finetuning_type lora \
    --quantization_bit 4 \
	--flash_attn True

推理所需显存为34318MiB

调用API接口

更多见七月的《大模型商用项目之审稿GPT微调实战》


网站公告

今日签到

点亮在社区的每一天
去签到