当前位置:CRM > 互联网资讯 > PyTorch 该怎么学?太简单了

PyTorch 该怎么学?太简单了

2024-03-07 18:02:16互联网资讯
2024-03-07,

挺多小伙伴问过PyTorch该怎么学,经过长期实践来看,初学者需要熟知的概念和用法真的不多,以下总结的简明指南一起看看吧!

构建Tensor

PyTorch 中的 Tensors 是多维数组,类似于 NumPy 的 ndarrays,但可以在 GPU 上运行:

import torch # Create a 2x3 tensor tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) print(tensor) 动态计算图

PyTorch 使用动态计算图,在执行操作时即时构建计算图,这为在运行时修改图形提供了灵活性:

# Define two tensors a = torch.tensor([2.], requires_grad=True) b = torch.tensor([3.], requires_grad=True) # Compute result c = a * b c.backward() # Gradients print(a.grad) # Gradient w.r.t a GPU加速

PyTorch 允许在 CPU 和 GPU 之间轻松切换。使用 .to(device) 即可:

device = "cuda" if torch.cuda.is_available() else "cpu" tensor = tensor.to(device) Autograd:自动微分

PyTorch 的 autograd 为tensor的所有运算提供了自动微分功能,设置 requires_grad=True可以跟踪计算:

x = torch.tensor([2.], requires_grad=True) y = x**2 y.backward() print(x.grad) # Gradient of y w.r.t x 模块化神经网络

PyTorch 提供了 nn.Module 类来定义神经网络架构,通过子类化创建自定义层:

import torch.nn as nn class SimpleNN(nn.Module): def __init__(self): super().__init__() self.fc = nn.Linear(1, 1) def forward(self, x): return self.fc(x) 预定义层和损失函数

PyTorch 在 nn 模块中提供了各种预定义层、损失函数和优化算法:

loss_fn = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.001) Dataset 与 DataLoader

为实现高效的数据处理和批处理,PyTorch 提供了 Dataset 和 DataLoader 类:

from torch.utils.data import Dataset, DataLoader class CustomDataset(Dataset): # ... (methods to define) data_loader = DataLoader(dataset, batch_size=32, shuffle=True) 模型训练(循环)

通常PyTorch 的训练遵循以下模式:前向传播、计算损失、反向传递和参数更新:

for epoch in range(epochs): for data, target in data_loader: optimizer.zero_grad() output = model(data) loss = loss_fn(output, target) loss.backward() optimizer.step() 模型序列化

使用 torch.save() 和 torch.load() 保存并加载模型:

# Save torch.save(model.state_dict(), 'model_weights.pth') # Load model.load_state_dict(torch.load('model_weights.pth')) JIT

PyTorch 默认以eager模式运行,但也为模型提供即时(JIT)编译:

scripted_model = torch.jit.script(model) scripted_model.save("model_jit.pt")PS:本文来源:PyTorch 该怎么学?太简单了,PyTorch,人工智能,作者:啥都生

版权声明:本文由CRM小助手整理收集与网络,仅供学习交流使用,不代表CRM论坛观点。如有侵权,请联系我们,我们将及时删除处理。

CRM论坛投稿:投稿地址


  CRM论坛(CRMBBS.COM)始办于2019年,是致力于CRM实施方案、免费CRM软件、SCRM系统、客户管理系统的垂直内容社区网站,CRM论坛持续专注于CRM领域,在不断深化理解CRM系统的同时,进一步利用新型互联网技术,为用户实现企业、客户、合作伙伴与产品之间的无缝连接与交互。

标签: PyTorch