Use this skill when implementing deep learning models with PyTorch. Covers model architecture design, custom layers, mixed precision training, distributed training (DDP/FSDP), gradient checkpointing, and model checkpointing.
# PyTorch Development Patterns
This skill provides comprehensive guidance for professional PyTorch development, from model design to distributed training.
## When to Activate
- Implementing new neural network architectures
- Setting up training pipelines
- Optimizing training efficiency
- Debugging model issues
- Implementing custom layers or loss functions
## Model Architecture Patterns
### Base Model Template
```python
import torch
import torch.nn as nn
from typing import Optional, Dict, Any


class BaseModel(nn.Module):
    """Base class for all models with common functionality."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.config = config
        self._build_model()

    def _build_model(self):
        """Override in subclass to build the architecture."""
        raise NotImplementedError

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass - override in subclass."""
        raise NotImplementedError

    def get_num_params(self, non_embedding: bool = True) -> int:
        """Count parameters, optionally excluding the embedding table."""
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding and hasattr(self, 'embedding'):
            n_params -= self.embedding.weight.numel()
        return n_params

    @classmethod
    def from_pretrained(cls, path: str, **kwargs):
        """Load a pretrained model from a checkpoint."""
        # weights_only=False is required to unpickle the config dict;
        # PyTorch >= 2.6 defaults torch.load to weights_only=True.
        checkpoint = torch.load(path, map_location='cpu', weights_only=False)
        config = checkpoint['config']
        model = cls(config, **kwargs)
        model.load_state_dict(checkpoint['model'])
        return model

    def save_pretrained(self, path: str):
        """Save model checkpoint with its config."""
        torch.save({
            'config': self.config,
            'model': self.state_dict(),
        }, path)
```
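A minimal usage sketch of the template above; the `TinyMLP` subclass, its config keys, and the checkpoint filename are hypothetical, invented here for illustration:

```python
class TinyMLP(BaseModel):
    """Hypothetical subclass showing how the template is filled in."""

    def _build_model(self):
        # Config keys ('in_dim', 'hidden_dim', 'out_dim') are assumptions
        # made up for this example, not part of the original skill.
        self.net = nn.Sequential(
            nn.Linear(self.config['in_dim'], self.config['hidden_dim']),
            nn.ReLU(),
            nn.Linear(self.config['hidden_dim'], self.config['out_dim']),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = TinyMLP({'in_dim': 16, 'hidden_dim': 32, 'out_dim': 4})
print(f"{model.get_num_params():,} parameters")
model.save_pretrained('tiny_mlp.pt')
restored = TinyMLP.from_pretrained('tiny_mlp.pt')
```

Storing the config inside the checkpoint is what lets `from_pretrained` rebuild the architecture before loading weights, so a checkpoint is self-describing rather than dependent on an external config file.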
### Transformer Block Pattern
```python
class TransformerBlock(nn.Module):
    """Standard transformer block with pre-norm."""

    def __init__(
        self,
        dim: int,
        num_heads: int,