colab使用本地数据集微调llama3-8b模型

        在Google的Colab上面采用unsloth,trl等库,训练数据集来自Google的云端硬盘,微调llama3-8b模型,进行推理验证模型的微调效果。

        保存模型到Google的云端硬盘可以下载到本地供其它使用。

准备工作:将训练数据集上传到google的云端硬盘根目录下,文件名就叫做train.json

train.json里面的数据格式如下:

[
  {
    "instruction": "你好",
    "output": "你好,我是智能助手胖胖"
  },
  {
    "instruction": "hello",
    "output": "Hello! I am 智能助手胖胖, an AI assistant developed by 丹宇码农. How can I assist you ?"
  }

......

]

采用unsloth库、trl库、transformers等库。

直接上代码:

%%capture
# Installs Unsloth, Xformers (Flash Attention) and all other packages!
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps "xformers<0.0.26" trl peft accelerate bitsandbytes

from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/mistral-7b-bnb-4bit",
    "unsloth/mistral-7b-instruct-v0.2-bnb-4bit",
    "unsloth/llama-2-7b-bnb-4bit",
    "unsloth/gemma-7b-bnb-4bit",
    "unsloth/gemma-7b-it-bnb-4bit", # Instruct version of Gemma 7b
    "unsloth/gemma-2b-bnb-4bit",
    "unsloth/gemma-2b-it-bnb-4bit", # Instruct version of Gemma 2b
    "unsloth/llama-3-8b-bnb-4bit", # [NEW] 15 Trillion token Llama-3
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)


alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):
    instructions = examples["instruction"]
    outputs      = examples["output"]
    texts = []
    for instruction, output in zip(instructions, outputs):
        input = ""
        # Must add EOS_TOKEN, otherwise your generation will go on forever!
        text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
        texts.append(text)
    return { "text" : texts, }
pass

from datasets import load_dataset
#dataset = load_dataset("yahma/alpaca-cleaned", split = "train")
#dataset = dataset.map(formatting_prompts_func, batched = True,)
from google.colab import drive
# 挂载云端硬盘,加载成功后,在左边的文件树中将会多一个 /content/drive/MyDrive/ 目录
drive.mount('/content/drive')


# 加载本地数据集:
# 有instruction和output,input为空字符串
from datasets import load_dataset

data_home = r"/content/drive/MyDrive/"
data_dict = {
    "train": os.path.join(data_home, "train.json"),
    #"validation": os.path.join(data_home, "dev.json"),
}
dataset = load_dataset("json", data_files=data_dict, split = "train")
print(dataset[0])
dataset = dataset.map(formatting_prompts_func, batched = True,)


from trl import SFTTrainer
from transformers import TrainingArguments

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 60,
        learning_rate = 2e-4,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
    ),
)

# 开始微调训练
trainer_stats = trainer.train()

#推理
# alpaca_prompt = Copied from above
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
[
    alpaca_prompt.format(
        "你是谁?", # instruction
        "", # input
        "", # output - leave this blank for generation!
    )
], return_tensors = "pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens = 64, use_cache = True)
tokenizer.batch_decode(outputs)

#此处输出的答案,能明显看到就是自己训练的数据,而不是原来模型的输出。说明微调起作用了


# 保存模型,改成挂接的云硬盘目录也可以保存到google的个人云存储空间,然后打开个人云存储空间下载到本地
model.save_pretrained("lora_model") # Local saving
tokenizer.save_pretrained("lora_model")

# Merge to 16bit
if True: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)

其实可以将.ipynb文件上传到个人云存储空间,双击这个文件就会打开colab,然后依次执行代码即可,随时可以增加、删除、修改,特别方便,还能免费使用GPU、CPU等资源,真的是广大AI爱好者的不错选择。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/632527.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

抖音小店有订单后怎么发货?实操分享!发货全流程来了

哈喽~我是电商月月 做无货源抖音小店的店铺在出单后怎么发货&#xff1f;今天我就来给大家解答这个问题&#xff0c;其中的注意事项新手商家可以收藏一下&#xff0c;避免犯错 抖音小店的商品出单后&#xff0c;商家在“管理中心-订单管理”页面就能看见所有待处理的订单 一…

IDEA中开发并部署运行WEB项目

IDEA中开发并部署运行WEB项目 1 WEB项目的标准结构2 WEB项目部署的方式3 IDEA中开发并部署运行WEB项目3.1 部署步骤3.2 IDEA关联本地Tomcat 4 IDEA创建web工程5 IDEA部署-运行web项目6 IDEA部署并运行项目的原理 1 WEB项目的标准结构 一个标准的可以用于发布的WEB项目标准结构如…

如何快速实现Java发送邮件功能?

如何使用JAVA发送邮件&#xff1f;&#xff08;不做过多文字解释&#xff0c;详细说明请看代码注释&#xff09; 一&#xff0c;引用相关pom二&#xff0c;实现代码&#xff08;代码参考图片内容&#xff09;三&#xff0c; 接收邮件 一&#xff0c;引用相关pom <dependency…

24长三角C题9页完整建模思路+可执行代码

比赛题目的完整版思路可执行代码数据参考论文都会在第一时间更新上传的&#xff0c;大家可以参考我往期的资料&#xff0c;所有的资料数据以及到最后更新的参考论文都是一次付费后续免费的。注意&#xff1a;&#xff08;建议先下单占坑&#xff0c;因为随着后续我们更新资料数…

信息安全相关内容

信息安全 安全防护体系 安全保护等级 安全防护策略 安全技术基础 安全防护体系 安全防护体系有7个等级 安全保护等级 安全保护等级有5个等级,从上到下是越来越安全的用户自主其实就是用户自己本身具有的相应的能力 安全防护策略 安全策略是对抗攻击的主要策略安全日志: …

Java Chassis 3:接口维度负载均衡

本文分享自华为云社区《Java Chassis 3技术解密&#xff1a;接口维度负载均衡》&#xff0c;作者&#xff1a; liubao68。 Java Chassis 3技术解密&#xff1a;接口维度负载均衡 在Java Chassis 3技术解密&#xff1a;负载均衡选择器中解密了Java Chassis 3负载均衡在解决性能…

社交媒体数据恢复:事秘达

社交软件-事秘达的数据恢复教程如下&#xff1a; 首先&#xff0c;你需要停止使用事秘达应用&#xff0c;以避免覆盖已经删除的数据。 接着&#xff0c;你需要连接你的手机到电脑上&#xff0c;并打开手机的USB调试功能。这可以让电脑读取你手机的数据。 下载并安装一个数据恢…

重生奇迹MU快速获取经验解析

重生奇迹MU觉醒卡级怎么办快速获取经验攻略&#xff0c;在游戏中卡级是玩家会遇到的情况之一&#xff0c;面对打不过的敌人和过不去的主线&#xff0c;想办法升级才是最主要的&#xff0c;游戏中有很多获取经验的途径。下面让我们一起来了解一下卡级后获取经验的攻略&#xff0…

PatterNodes 3 mac矢量图设计 ,色彩与图案的完美融合!

PatterNodes 3 for Mac是一款功能强大的矢量图形模式创建软件&#xff0c;专为Mac用户设计。它采用基于节点的界面&#xff0c;支持创建形状、线条、曲线或文本&#xff0c;以构建复杂的矢量图形模式。该软件还具备灵活的参数调整功能&#xff0c;允许用户实时预览结果并进行无…

PLCnext用三种方式去编写一个功能块

先前提到的基于eCLR&#xff0c;PLCnext可以通过几种高级语言编写功能块、函数、等等&#xff0c;今天我们来试一下利用IEC61131、C、C#去制作加法功能块。 1.准备工具 PLCnext Engineer & 1152 Simulator PLCnext Engineer是上位编程软件&#xff0c;免费&#xff0c;11…

通俗易懂的策略模式讲解

什么是策略模式&#xff1f; 策略模式是一种设计模式&#xff0c;它允许你定义一系列的算法&#xff08;策略&#xff09;&#xff0c;并将每个算法封装成一个对象。这样&#xff0c;你可以轻松地切换不同的算法&#xff0c;而不需要改变原始代码。 一个简单的例子 假设你是…

ACM实训冲刺第八天

【碎碎念】由于昨天做的题都有思路&#xff0c;加上今天有点疲惫&#xff0c;故今天先不复习了&#xff0c;直接开始今天的算法学习 Tokitsukaze and All Zero Sequence 问题 思路 读入测试用例数&#xff1a;首先读取一个整数t&#xff0c;表示接下来会有t组数据需要处理。遍…

【达梦数据库】搭建 DM->mysql dblink

DM->mysql dblink 1安装mysql odbc rpm -ivh mysql-connector-odbc-5.3.14-1.el7.x86_64.rpm2mysql创建远程用户与远程数据库 mysql> show databases; ------------------------- | Database | ------------------------- | information_schema | …

【Linux系统编程】第十九弹---进程状态(下)

​​​​​​​ ✨个人主页&#xff1a; 熬夜学编程的小林 &#x1f497;系列专栏&#xff1a; 【C语言详解】 【数据结构详解】【C详解】【Linux系统编程】 目录 1、僵尸进程 2、孤儿进程 3、运行状态 4、阻塞状态 5、挂起状态 6、进程切换 总结 1、僵尸进程 上一弹…

OpenHarmony 3GPP协议开发深度剖析——一文读懂RIL

市面上关于终端&#xff08;手机&#xff09;操作系统在 3GPP 协议开发的内容太少了&#xff0c;即使 Android 相关的学习文档都很少&#xff0c;Android 协议开发书籍我是没有见过的。可能是市场需求的缘故吧&#xff0c;现在市场上还是前后端软件开发从业人员最多&#xff0c…

【Open AI】GPT-4o深夜发布:视觉、听觉跨越式升级

北京时间5月14日1点整&#xff0c;OpenAI 召开了首场春季发布会&#xff0c;CTO Mira Murati 在台上和团队用短短不到30分钟的时间&#xff0c;揭开了最新旗舰模型 GPT-4o 的神秘面纱&#xff0c;以及基于 GPT-4o 的 ChatGPT&#xff0c;均为免费使用。 本文内容来自OpenAI网站…

微服务架构:注册中心 Eureka、ZooKeeper、Consul、Nacos的选型对比详解

微服务架构&#xff08;Microservices Architecture&#xff09;是一种基于服务拆分的分布式架构模式&#xff0c;旨在将复杂的单体应用程序拆分为一组更小、更独立的服务单元。这些服务单元可以独立开发、测试、部署&#xff0c;并使用不同的技术栈和编程语言。它们通过轻量级…

外贸业务中的12个“坑”,你踩到了吗?

在竞争激烈的外贸领域&#xff0c;企业在拓展市场的同时&#xff0c;也面临着各种潜在的陷阱和风险。对于外贸公司而言&#xff0c;如何在复杂的交易过程中识破陷阱&#xff0c;防范潜在风险&#xff0c;成为确保企业长远发展的关键一环。 以下是一些外贸企业可能遇到的陷阱&a…

Nebula街机模拟器 Mac移植版(400+游戏roms)汉化版

nebula星云模拟器是电脑上最热门的街机游戏模拟器之一&#xff0c;玩家可以通过这个小巧的模拟器软件进行多款经典街机游戏启动和畅玩&#xff0c;本次移植的包含400多款游戏roms&#xff0c;经典的三国志、三国战纪、拳皇、街霸、合金弹头、1941都包含在内。 下载地址&#xf…

电感式传感器

电感传感器是基于电磁感应原理&#xff0c;将被测非电量&#xff08;如位移、压力、振动等&#xff09;转换为电感量变化的一种结构性传感器。利用自感原理的有自感式传感器&#xff08;可变磁阻式&#xff09;&#xff0c;利用互感原理的有互感式&#xff08;差动变压器式和涡…