Spaces: Running on Zero
Commit eeef97b (0 parents): Deploy to HF spaces
- .gitattributes +2 -0
- .gitignore +77 -0
- README.md +124 -0
- app.py +472 -0
- load/tets/160_tets.npz +3 -0
- requirements.txt +40 -0
- sf3d/models/camera.py +32 -0
- sf3d/models/global_estimator/multi_head_estimator.py +118 -0
- sf3d/models/image_estimator/clip_based_estimator.py +168 -0
- sf3d/models/isosurface.py +229 -0
- sf3d/models/mesh.py +289 -0
- sf3d/models/network.py +213 -0
- sf3d/models/tokenizers/dinov2.py +1196 -0
- sf3d/models/tokenizers/image.py +101 -0
- sf3d/models/tokenizers/triplane.py +49 -0
- sf3d/models/transformers/attention.py +31 -0
- sf3d/models/transformers/backbone.py +515 -0
- sf3d/models/utils.py +236 -0
- sf3d/system.py +534 -0
- sf3d/utils.py +105 -0
- texture_baker/README.md +26 -0
- texture_baker/requirements.txt +2 -0
- texture_baker/setup.py +142 -0
- texture_baker/texture_baker/__init__.py +4 -0
- texture_baker/texture_baker/baker.py +86 -0
- texture_baker/texture_baker/csrc/baker.cpp +548 -0
- texture_baker/texture_baker/csrc/baker.h +203 -0
- texture_baker/texture_baker/csrc/baker_kernel.cu +306 -0
- texture_baker/texture_baker/csrc/baker_kernel.metal +170 -0
- texture_baker/texture_baker/csrc/baker_kernel.mm +260 -0
- uv_unwrapper/README.md +0 -0
- uv_unwrapper/requirements.txt +2 -0
- uv_unwrapper/setup.py +83 -0
- uv_unwrapper/uv_unwrapper/__init__.py +6 -0
- uv_unwrapper/uv_unwrapper/csrc/bvh.cpp +381 -0
- uv_unwrapper/uv_unwrapper/csrc/bvh.h +118 -0
- uv_unwrapper/uv_unwrapper/csrc/common.h +493 -0
- uv_unwrapper/uv_unwrapper/csrc/intersect.cpp +702 -0
- uv_unwrapper/uv_unwrapper/csrc/intersect.h +10 -0
- uv_unwrapper/uv_unwrapper/csrc/unwrapper.cpp +271 -0
- uv_unwrapper/uv_unwrapper/unwrap.py +669 -0
.gitattributes
ADDED
@@ -0,0 +1,2 @@
*.npz filter=lfs diff=lfs merge=lfs -text
load/tets/160_tets.npz filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,77 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual environments
venv/
env/
ENV/
.venv

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# Jupyter Notebook
.ipynb_checkpoints

# Environment variables
.env
.env.local

# Model cache
.cache/
*.safetensors
*.ckpt
*.pt
*.pth

# Generated files
output/
*.glb
*.gltf
*.obj
*.ply

# Gradio temp files
gradio_cached_examples/
flagged/

# OS
.DS_Store
Thumbs.db

# Logs
*.log
logs/

# Temporary files
tmp/
temp/
*.tmp

# Hugging Face
.huggingface/

references/
README.md
ADDED
@@ -0,0 +1,124 @@
---
title: Stable Diffusion Fast Text to 3D
emoji: 🎨
colorFrom: red
colorTo: pink
sdk: gradio
sdk_version: 6.1.0
app_file: app.py
pinned: false
license: other
license_name: stabilityai-ai-community
license_link: https://huggingface.co/stabilityai/stable-fast-3d/blob/main/LICENSE.md
models:
- stabilityai/stable-diffusion-xl-base-1.0
- stabilityai/stable-fast-3d
gpu: true
---

# Text to Image to 3D Generation

This Hugging Face Space provides a complete workflow to generate 3D models from text prompts using:

1. **Stable Diffusion XL** - Generate high-quality images from text prompts
2. **rembg** - Remove backgrounds from generated images
3. **Stable Fast 3D** - Convert images to 3D mesh models

## Features

- 🎨 **Text to Image**: Generate images using the Stable Diffusion XL base model
- ✂️ **Background Removal**: Automatically remove backgrounds using rembg
- 🎮 **3D Generation**: Create textured 3D mesh models from images
- 🔄 **Step-by-step Workflow**: Review and confirm at each step
- ⚙️ **Customizable**: Adjust remeshing options, vertex count, and texture resolution

## How to Use

1. **Step 1 - Text to Image**:
   - Enter a text prompt describing what you want to generate
   - Optionally add a negative prompt to exclude unwanted elements
   - Adjust the number of inference steps (more steps = higher quality, but slower)
   - Click "Generate Image" and wait for the result

2. **Step 2 - Background Removal**:
   - Review the generated image
   - Click "Continue to Background Removal" to remove the background
   - Preview the result with transparency

3. **Step 3 - 3D Generation**:
   - Review the background-removed image
   - Adjust the 3D generation settings:
     - **Remeshing Option**: Choose "none", "triangle", or "quad" remeshing
     - **Target Vertex Count**: Set to -1 for automatic, or specify a target count
     - **Texture Size**: Choose the texture resolution (512-2048)
   - Click "Continue to 3D Generation" to create the 3D model
   - Download your GLB file

## Tips

- **Prompts**: Be descriptive and specific. Include style keywords like "3D render", "character", "stylized".
- **Background Removal**: Works best with clear foreground objects.
- **3D Generation**:
  - Use "none" remeshing for the best quality
  - Higher texture sizes produce better quality but take longer
  - A vertex count of -1 uses the model's default

## Technical Details

- **Models Used**:
  - `stabilityai/stable-diffusion-xl-base-1.0` for text-to-image
  - `rembg` for background removal
  - `stabilityai/stable-fast-3d` for image-to-3D

- **Output Format**: GLB (glTF Binary) files compatible with most 3D software and viewers

- **GPU Resource Management**:
  - Uses `@spaces.GPU()` decorators to properly manage GPU resources in Hugging Face Spaces
  - The GPU is allocated for text-to-image generation, background removal, and 3D mesh generation
  - Ensures efficient GPU usage across the workflow

## Requirements

This Space requires:
- GPU support (recommended for faster generation)
- Sufficient memory for model loading
- An internet connection for model downloads
- **Access to gated models**: The `stabilityai/stable-fast-3d` model is gated:
  1. Accept the model's terms of use on [Hugging Face](https://huggingface.co/stabilityai/stable-fast-3d)
  2. The Space then authenticates automatically using the `HF_TOKEN` environment variable

### Dependencies

This Space uses several models and packages:

1. **Stable Diffusion XL**: Automatically downloaded from Hugging Face
2. **rembg**: Installed via pip (included in requirements.txt)
3. **Stable Fast 3D**:
   - The model weights are downloaded from Hugging Face
   - The `sf3d` Python package is included in this repository
   - **Note**: This is a gated model; access must be granted by Stability AI
4. **texture_baker and uv_unwrapper**:
   - These packages are included in the repository
   - They are automatically compiled and installed at runtime when the app starts
   - Installation may take a few minutes on first run
   - The CUDA architecture is detected automatically, with a fallback list if detection fails

### Authentication

- **Hugging Face Token**: The Space automatically authenticates using the `HF_TOKEN` environment variable
- **Gated Model Access**: You must accept the terms and request access to `stabilityai/stable-fast-3d` on Hugging Face
- Authentication happens at startup, so all model downloads use the authenticated session

All required packages (`sf3d`, `texture_baker`, `uv_unwrapper`, and `load/tets`) are included in this repository, so no additional setup is needed.

## License

This Space uses models with the following licenses:
- Stable Diffusion XL: [CreativeML Open RAIL++-M License](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md)
- Stable Fast 3D: [Stability AI Community License](https://huggingface.co/stabilityai/stable-fast-3d/blob/main/LICENSE.md)

## Credits

- [Stability AI](https://stability.ai/) for Stable Diffusion XL and Stable Fast 3D
- [rembg](https://github.com/danielgatis/rembg) for background removal
- Built with [Gradio](https://gradio.app/)
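The three steps described above map directly onto the helpers defined in app.py below. A minimal sketch of running the same pipeline outside the Gradio UI, assuming the gated `stabilityai/stable-fast-3d` weights are accessible via `HF_TOKEN` and that `texture_baker`/`uv_unwrapper` build successfully at import time:

```python
# Minimal sketch reusing the helpers from app.py (importing app runs its
# startup code: HF login, runtime package install, and SF3D model loading).
from app import (
    generate_text_to_image,
    remove_background_from_image,
    generate_3d_from_image,
)

image = generate_text_to_image(
    "A cute robot character, 3D render, colorful", num_inference_steps=30
)
image_rgba = remove_background_from_image(image.convert("RGB"))
glb_path = generate_3d_from_image(
    image_rgba, remesh_option="none", vertex_count=-1, texture_size=1024
)
print(f"GLB written to {glb_path}")
```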
app.py
ADDED
@@ -0,0 +1,472 @@
import spaces
import torch
import os
import tempfile
import time
from contextlib import nullcontext
from functools import lru_cache
from typing import Any

import gradio as gr
import numpy as np
import rembg
from diffusers import DiffusionPipeline
from gradio_litmodel3d import LitModel3D
from huggingface_hub import login
from PIL import Image

# Authenticate with Hugging Face using token from environment
# HF_TOKEN is automatically available in Hugging Face Spaces
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    # Login to Hugging Face - this stores the token for all HF Hub operations
    login(token=hf_token)
    # Also ensure it's set as environment variable for any libraries that check it directly
    os.environ["HF_TOKEN"] = hf_token
    print("Authenticated with Hugging Face")
else:
    print("Warning: HF_TOKEN not found. Gated models may not be accessible.")
    print("Please ensure HF_TOKEN is set in your Space's secrets.")

if not torch.cuda.is_available():
    raise Exception("CUDA is not available")

# Set environment variables for building texture_baker and uv_unwrapper
os.environ["USE_CUDA"] = "1"
os.environ["USE_NATIVE_ARCH"] = "0"  # Disable native arch to avoid build issues

# Set CUDA architecture list to avoid detection issues
# PyTorch's build system fails when it can't detect GPU architectures
# Setting TORCH_CUDA_ARCH_LIST explicitly prevents this error
if torch.cuda.is_available():
    try:
        # Try to get the actual compute capability
        compute_cap = torch.cuda.get_device_capability(0)
        cuda_arch = f"{compute_cap[0]}.{compute_cap[1]}"
        os.environ["TORCH_CUDA_ARCH_LIST"] = cuda_arch
        print(
            f"Detected CUDA capability: {cuda_arch}, setting TORCH_CUDA_ARCH_LIST={cuda_arch}"
        )
    except Exception as e:
        # Fallback to common architectures if detection fails
        # Include multiple architectures to support various GPU models
        fallback_archs = "7.0;7.5;8.0;8.6;8.9;9.0"
        os.environ["TORCH_CUDA_ARCH_LIST"] = fallback_archs
        print(
            f"Could not detect CUDA capability: {e}, using fallback architectures: {fallback_archs}"
        )
else:
    # Should not happen since we check above, but just in case
    print("Warning: CUDA not available but trying to build with CUDA support")

os.system(
    "USE_CUDA=1 USE_NATIVE_ARCH=0 pip install -vv --no-build-isolation ./texture_baker ./uv_unwrapper"
)

import sf3d.utils as sf3d_utils
from sf3d.system import SF3D

# Set up environment
os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.environ.get("TMPDIR", "/tmp"), "gradio")

# Initialize rembg session
rembg_session = rembg.new_session()

# Constants for 3D generation
COND_WIDTH = 512
COND_HEIGHT = 512
COND_DISTANCE = 1.6
COND_FOVY_DEG = 40
BACKGROUND_COLOR = [0.5, 0.5, 0.5]

# Cached. Doesn't change
c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
    COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
)

generated_files = []

# Initialize device and SF3D model (like official app)
device = sf3d_utils.get_device()

# SF3D model - initialized at startup like official app
# Token is automatically used after login() call above
sf3d_model = SF3D.from_pretrained(
    "stabilityai/stable-fast-3d",
    config_name="config.yaml",
    weight_name="model.safetensors",
)
sf3d_model.eval()
sf3d_model = sf3d_model.to(device)

# SDXL pipeline - lazy loaded to save memory
sd_pipeline = None


def initialize_sdxl():
    """Initialize SDXL pipeline on first use."""
    global sd_pipeline, device

    if sd_pipeline is None:
        print("Loading Stable Diffusion XL model...")
        sd_pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            use_safetensors=True,
            variant="fp16" if device == "cuda" else None,
        )
        if device == "cuda":
            sd_pipeline = sd_pipeline.to(device)
            # Enable memory efficient attention if available
            try:
                sd_pipeline.enable_xformers_memory_efficient_attention()
            except Exception:
                pass
        elif device == "mps":
            sd_pipeline = sd_pipeline.to(device)
        else:
            sd_pipeline.enable_model_cpu_offload()
        print("SDXL model loaded!")

    return sd_pipeline


@spaces.GPU()
def generate_text_to_image(
    prompt: str, negative_prompt: str = "", num_inference_steps: int = 30
):
    """Generate image from text prompt using SDXL."""
    pipeline = initialize_sdxl()

    print(f"Generating image from prompt: {prompt}")

    # Generate image
    with torch.no_grad():
        if device == "cuda":
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                image = pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt if negative_prompt else None,
                    num_inference_steps=num_inference_steps,
                ).images[0]
        else:
            image = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt if negative_prompt else None,
                num_inference_steps=num_inference_steps,
            ).images[0]

    return image


@spaces.GPU()
def remove_background_from_image(image: Image.Image) -> Image.Image:
    """Remove background from image using rembg."""
    print("Removing background...")
    result = rembg.remove(image, session=rembg_session)
    return result


def create_batch(input_image: Image.Image) -> dict[str, Any]:
    """Create batch for SF3D model - matches official app structure."""
    img_cond = (
        torch.from_numpy(
            np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32)
            / 255.0
        )
        .float()
        .clip(0, 1)
    )
    mask_cond = img_cond[:, :, -1:]
    rgb_cond = torch.lerp(
        torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
    )

    batch_elem = {
        "rgb_cond": rgb_cond,
        "mask_cond": mask_cond,
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
    }
    # Add batch dim
    batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
    return batched


def run_model(input_image, remesh_option, vertex_count, texture_size):
    """Run SF3D model - matches official app structure."""
    start = time.time()
    with torch.no_grad():
        with (
            torch.autocast(device_type=device, dtype=torch.bfloat16)
            if "cuda" in device
            else nullcontext()
        ):
            model_batch = create_batch(input_image)
            model_batch = {k: v.to(device) for k, v in model_batch.items()}
            trimesh_mesh, _glob_dict = sf3d_model.generate_mesh(
                model_batch, texture_size, remesh_option.lower(), vertex_count
            )
            trimesh_mesh = trimesh_mesh[0]

    # Create new tmp file
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")

    trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
    generated_files.append(tmp_file.name)

    print("Generation took:", time.time() - start, "s")

    return tmp_file.name


@spaces.GPU()
def generate_3d_from_image(
    input_image: Image.Image,
    remesh_option: str = "none",
    vertex_count: int = -1,
    texture_size: int = 1024,
) -> str:
    """Generate 3D mesh from image using SF3D."""
    # Resize foreground if needed (like official app)
    foreground_ratio = 0.85
    processed_image = sf3d_utils.resize_foreground(
        input_image, foreground_ratio, out_size=(COND_WIDTH, COND_HEIGHT)
    )

    return run_model(processed_image, remesh_option, vertex_count, texture_size)


@lru_cache
def checkerboard(squares: int, size: int, min_value: float = 0.5):
    """Create checkerboard pattern for transparency preview."""
    base = np.zeros((squares, squares)) + min_value
    base[1::2, ::2] = 1
    base[::2, 1::2] = 1

    repeat_mult = size // squares
    return (
        base.repeat(repeat_mult, axis=0)
        .repeat(repeat_mult, axis=1)[:, :, None]
        .repeat(3, axis=-1)
    )


def show_mask_preview(input_image: Image.Image) -> Image.Image:
    """Show image with checkerboard background for transparency preview."""
    img_numpy = np.array(input_image)
    alpha = img_numpy[:, :, 3] / 255.0
    chkb = checkerboard(32, 512) * 255
    new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
    return Image.fromarray(new_img.astype(np.uint8), mode="RGB")


# Gradio Interface Functions
def step1_generate_image(prompt, negative_prompt, num_steps):
    """Step 1: Generate image from text."""
    if not prompt:
        return None, gr.update(visible=False), "Please enter a prompt"

    try:
        image = generate_text_to_image(prompt, negative_prompt, num_steps)
        return (
            image,
            gr.update(visible=True, value="Continue to Background Removal"),
            "Image generated successfully! Review and continue to remove background.",
        )
    except Exception as e:
        return None, gr.update(visible=False), f"Error generating image: {str(e)}"


def step2_remove_background(image):
    """Step 2: Remove background from image."""
    if image is None:
        return None, None, gr.update(visible=False), "Please generate an image first"

    try:
        # Convert to RGB if needed
        if image.mode != "RGB":
            image = image.convert("RGB")

        bg_removed = remove_background_from_image(image)
        preview = show_mask_preview(bg_removed)

        return (
            bg_removed,
            preview,
            gr.update(visible=True, value="Continue to 3D Generation"),
            "Background removed successfully! Review and continue to generate 3D model.",
        )
    except Exception as e:
        return (
            None,
            None,
            gr.update(visible=False),
            f"Error removing background: {str(e)}",
        )


def step3_generate_3d(image_with_bg_removed, remesh_option, vertex_count, texture_size):
    """Step 3: Generate 3D model from image."""
    if image_with_bg_removed is None:
        return gr.update(value=None, visible=False), "Please remove background first"

    try:
        glb_file = generate_3d_from_image(
            image_with_bg_removed, remesh_option, vertex_count, texture_size
        )

        return (
            gr.update(value=glb_file, visible=True),
            "3D model generated successfully! You can download it below.",
        )
    except Exception as e:
        return (
            gr.update(value=None, visible=False),
            f"Error generating 3D model: {str(e)}",
        )


# Create Gradio Interface
with gr.Blocks(title="Text to Image to 3D") as demo:
    gr.Markdown(
        """
        # Text to Image to 3D Generation

        This app allows you to generate 3D models from text prompts in three steps:
        1. **Text to Image**: Generate an image using Stable Diffusion XL
        2. **Background Removal**: Remove the background using rembg
        3. **3D Generation**: Create a 3D mesh model using Stable Fast 3D

        **Instructions:**
        - Enter your text prompt and generate an image
        - Review the generated image and continue to remove the background
        - Review the background-removed image and continue to generate the 3D model
        - Download your 3D model as a GLB file
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Step 1: Text to Image")
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="A cute robot character, 3D render, colorful",
                lines=2,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt (optional)",
                placeholder="blurry, low quality, distorted",
                lines=2,
            )
            num_steps = gr.Slider(
                label="Number of Inference Steps",
                minimum=20,
                maximum=50,
                value=30,
                step=5,
            )
            generate_btn = gr.Button("Generate Image", variant="primary")
            step1_status = gr.Textbox(label="Status", interactive=False)

            step1_image = gr.Image(label="Generated Image", type="pil")
            step1_continue_btn = gr.Button(
                "Continue to Background Removal",
                visible=False,
                variant="secondary",
            )

        with gr.Column(scale=1):
            gr.Markdown("### Step 2: Background Removal")
            step2_image = gr.Image(label="Image with Background Removed", type="pil")
            step2_preview = gr.Image(
                label="Preview (with transparency)",
                type="pil",
                visible=False,
            )
            step2_status = gr.Textbox(label="Status", interactive=False)
            step2_continue_btn = gr.Button(
                "Continue to 3D Generation",
                visible=False,
                variant="secondary",
            )

        with gr.Column(scale=1):
            gr.Markdown("### Step 3: 3D Generation")
            remesh_option = gr.Radio(
                choices=["none", "triangle", "quad"],
                label="Remeshing Option",
                value="none",
            )
            vertex_count = gr.Slider(
                label="Target Vertex Count (-1 for auto)",
                minimum=-1,
                maximum=20000,
                value=-1,
                step=100,
            )
            texture_size = gr.Slider(
                label="Texture Size",
                minimum=512,
                maximum=2048,
                value=1024,
                step=256,
            )
            step3_generate_btn = gr.Button("Generate 3D Model", variant="primary")
            step3_status = gr.Textbox(label="Status", interactive=False)

            step3_output = LitModel3D(
                label="3D Model",
                visible=False,
                clear_color=[0.0, 0.0, 0.0, 0.0],
            )

    # State variables
    step2_image_state = gr.State()

    # Event handlers
    generate_btn.click(
        fn=step1_generate_image,
        inputs=[prompt, negative_prompt, num_steps],
        outputs=[step1_image, step1_continue_btn, step1_status],
    )

    step1_continue_btn.click(
        fn=step2_remove_background,
        inputs=[step1_image],
        outputs=[step2_image, step2_preview, step2_continue_btn, step2_status],
    ).then(
        fn=lambda img: img,
        inputs=[step2_image],
        outputs=[step2_image_state],
    )

    step2_continue_btn.click(
        fn=step3_generate_3d,
        inputs=[step2_image_state, remesh_option, vertex_count, texture_size],
        outputs=[step3_output, step3_status],
    )

    # Update preview when image changes
    step2_image.change(
        fn=show_mask_preview,
        inputs=[step2_image],
        outputs=[step2_preview],
    ).then(
        fn=lambda: gr.update(visible=True),
        outputs=[step2_preview],
    )


if __name__ == "__main__":
    # Delete previous gradio temp dir folder (like official app)
    if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
        print(f"Deleting {os.environ['GRADIO_TEMP_DIR']}")
        import shutil

        shutil.rmtree(os.environ["GRADIO_TEMP_DIR"])

    demo.queue()
    demo.launch(share=False)
load/tets/160_tets.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f4be37efc604d28d55a1a78c2aabefeeab7e63149f541aa45f9dd858ee35bb9
size 15408790
requirements.txt
ADDED
@@ -0,0 +1,40 @@
wheel
setuptools==69.5.1

# Core dependencies
torch==2.5.1
torchvision==0.20.1
numpy==1.26.4
Pillow>=9.5.0

# Stable Diffusion XL
diffusers>=0.21.0
transformers==4.42.3
accelerate>=0.20.0
safetensors>=0.3.0
invisible-watermark>=0.2.0

# Background removal
rembg[gpu]==2.0.57; sys_platform != 'darwin'
rembg==2.0.57; sys_platform == 'darwin'

# Stable Fast 3D dependencies
einops==0.7.0
jaxtyping==0.2.31
omegaconf==2.3.0
open_clip_torch==2.24.0
trimesh==4.4.1
huggingface-hub>=0.23.2,<1.0
pynanoinstantmeshes==0.0.3
gpytoolbox==0.2.0

# Gradio and UI
gradio==4.41.0
gradio-litmodel3d==0.0.1

# Additional utilities
tqdm>=4.65.0

# (HF hack) These local packages are installed at runtime by app.py
# ./texture_baker/
# ./uv_unwrapper/
sf3d/models/camera.py
ADDED
@@ -0,0 +1,32 @@
from dataclasses import dataclass, field
from typing import List

import torch
import torch.nn as nn

from sf3d.models.utils import BaseModule


class LinearCameraEmbedder(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 25
        out_channels: int = 768
        conditions: List[str] = field(default_factory=list)

    cfg: Config

    def configure(self) -> None:
        self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)

    def forward(self, **kwargs):
        cond_tensors = []
        for cond_name in self.cfg.conditions:
            assert cond_name in kwargs
            cond = kwargs[cond_name]
            # cond in shape (B, Nv, ...)
            cond_tensors.append(cond.view(*cond.shape[:2], -1))
        cond_tensor = torch.cat(cond_tensors, dim=-1)
        assert cond_tensor.shape[-1] == self.cfg.in_channels
        embedding = self.linear(cond_tensor)
        return embedding
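`LinearCameraEmbedder` flattens each requested camera condition over its trailing dimensions, concatenates them per view, and projects the result with a single linear layer. A minimal sketch of that shape math, assuming the conditions are a 4x4 camera-to-world matrix and a 3x3 normalized intrinsic (16 + 9 = 25 features, matching the default `in_channels`):

```python
import torch
import torch.nn as nn

# Shape-only sketch of LinearCameraEmbedder.forward (without the BaseModule
# config machinery); B = batch size, Nv = number of conditioning views.
B, Nv = 2, 1
c2w_cond = torch.randn(B, Nv, 4, 4)
intrinsic_normed_cond = torch.randn(B, Nv, 3, 3)

cond_tensor = torch.cat(
    [c2w_cond.view(B, Nv, -1), intrinsic_normed_cond.view(B, Nv, -1)], dim=-1
)  # (B, Nv, 25)
embedding = nn.Linear(25, 768)(cond_tensor)  # (B, Nv, 768): one embedding per view
print(embedding.shape)  # torch.Size([2, 1, 768])
```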
sf3d/models/global_estimator/multi_head_estimator.py
ADDED
@@ -0,0 +1,118 @@
from dataclasses import dataclass, field
from typing import Any, List, Optional

import torch.nn as nn
from jaxtyping import Float
from torch import Tensor

from sf3d.models.network import get_activation
from sf3d.models.utils import BaseModule


@dataclass
class HeadSpec:
    name: str
    out_channels: int
    n_hidden_layers: int
    output_activation: Optional[str] = None
    output_bias: float = 0.0
    add_to_decoder_features: bool = False
    shape: Optional[list[int]] = None


class MultiHeadEstimator(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        triplane_features: int = 1024

        n_layers: int = 2
        hidden_features: int = 512
        activation: str = "relu"

        pool: str = "max"
        # Literal["mean", "max"] = "mean"  # noqa: F821

        heads: List[HeadSpec] = field(default_factory=lambda: [])

    cfg: Config

    def configure(self):
        layers = []
        cur_features = self.cfg.triplane_features * 3
        for _ in range(self.cfg.n_layers):
            layers.append(
                nn.Conv2d(
                    cur_features,
                    self.cfg.hidden_features,
                    kernel_size=3,
                    padding=0,
                    stride=2,
                )
            )
            layers.append(self.make_activation(self.cfg.activation))

            cur_features = self.cfg.hidden_features

        self.layers = nn.Sequential(*layers)

        assert len(self.cfg.heads) > 0
        heads = {}
        for head in self.cfg.heads:
            head_layers = []
            for i in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                ]
            head_layers += [
                nn.Linear(
                    self.cfg.hidden_features,
                    head.out_channels,
                ),
            ]
            heads[head.name] = nn.Sequential(*head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def forward(
        self,
        triplane: Float[Tensor, "B 3 F Ht Wt"],
    ) -> dict[str, Any]:
        x = self.layers(
            triplane.reshape(
                triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1]
            )
        )

        if self.cfg.pool == "max":
            x = x.amax(dim=[-2, -1])
        elif self.cfg.pool == "mean":
            x = x.mean(dim=[-2, -1])
        else:
            raise NotImplementedError

        out = {
            ("decoder_" if head.add_to_decoder_features else "")
            + head.name: get_activation(head.output_activation)(
                self.heads[head.name](x) + head.output_bias
            )
            for head in self.cfg.heads
        }
        for head in self.cfg.heads:
            if head.shape:
                head_name = (
                    "decoder_" if head.add_to_decoder_features else ""
                ) + head.name
                out[head_name] = out[head_name].reshape(*head.shape)

        return out
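The estimator stacks the three triplane planes along the channel axis, runs a small strided conv stack, and global-pools the result so each head sees one fixed-size feature vector per sample. A rough shape sketch under the default config (`triplane_features=1024`, `hidden_features=512`, `n_layers=2`); the 96x96 plane resolution is an assumption for illustration, and the activations between convs are omitted:

```python
import torch
import torch.nn as nn

B, F_tri, Ht, Wt = 1, 1024, 96, 96             # plane resolution assumed, not from the config
triplane = torch.randn(B, 3, F_tri, Ht, Wt)

x = triplane.reshape(B, 3 * F_tri, Ht, Wt)      # (1, 3072, 96, 96): planes stacked as channels
x = nn.Conv2d(3 * F_tri, 512, 3, stride=2)(x)   # (1, 512, 47, 47)
x = nn.Conv2d(512, 512, 3, stride=2)(x)         # (1, 512, 23, 23)
pooled = x.amax(dim=[-2, -1])                   # (1, 512): one vector per sample for the head MLPs
print(pooled.shape)
```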
sf3d/models/image_estimator/clip_based_estimator.py
ADDED
@@ -0,0 +1,168 @@
from dataclasses import dataclass, field
from typing import Any, List, Optional

import open_clip
import torch
import torch.nn as nn
from jaxtyping import Float
from torch import Tensor
from torchvision.transforms import Normalize

from sf3d.models.network import get_activation
from sf3d.models.utils import BaseModule


@dataclass
class HeadSpec:
    name: str
    out_channels: int
    n_hidden_layers: int
    output_activation: Optional[str] = None
    output_bias: float = 0.0
    add_to_decoder_features: bool = False
    shape: Optional[list[int]] = None


class ClipBasedHeadEstimator(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        model: str = "ViT-B-32"
        pretrain: str = "laion2b_s34b_b79k"

        distribution: str = "beta"

        # ["mean", "mode", "sample", "sample_mean"]
        distribution_eval: str = "mode"

        activation: str = "relu"
        hidden_features: int = 512
        heads: List[HeadSpec] = field(default_factory=lambda: [])

    cfg: Config

    def configure(self):
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.cfg.model, pretrained=self.cfg.pretrain
        )
        self.model.eval()

        # Do not add the weights in self.model to the optimizer
        for param in self.model.parameters():
            param.requires_grad = False

        assert len(self.cfg.heads) > 0
        heads = {}
        for head in self.cfg.heads:
            head_layers = []

            for i in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                ]

            head_layers = [nn.Sequential(*head_layers)]
            head_layers += [
                nn.Sequential(
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                    nn.Linear(self.cfg.hidden_features, 1),
                )
                for _ in range(2)
            ]
            heads[head.name] = nn.ModuleList(head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def forward(
        self,
        cond_image: Float[Tensor, "B 1 H W 3"],
        sample: bool = True,
    ) -> dict[str, Any]:
        # Run the model
        # Resize cond_image to 224
        cond_image = nn.functional.interpolate(
            cond_image.flatten(0, 1).permute(0, 3, 1, 2).contiguous(),
            size=(224, 224),
            mode="bilinear",
            align_corners=False,
        )
        cond_image = Normalize(
            mean=open_clip.constants.OPENAI_DATASET_MEAN,
            std=open_clip.constants.OPENAI_DATASET_STD,
        )(cond_image)
        image_features = self.model.encode_image(cond_image)

        # Run the heads
        outputs = {}

        for head_dict in self.cfg.heads:
            head_name = head_dict.name
            shared_head, d1_h, d2_h = self.heads[head_name]
            shared_features = shared_head(image_features)
            d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]
            if self.cfg.distribution == "normal":
                mean = d1
                var = d2
                if mean.shape[-1] == 1:
                    outputs[head_name] = torch.distributions.Normal(
                        mean + head_dict.output_bias,
                        torch.nn.functional.softplus(var),
                    )
                else:
                    outputs[head_name] = torch.distributions.MultivariateNormal(
                        mean + head_dict.output_bias,
                        torch.nn.functional.softplus(var).diag_embed(),
                    )
            elif self.cfg.distribution == "beta":
                outputs[head_name] = torch.distributions.Beta(
                    torch.nn.functional.softplus(d1 + head_dict.output_bias),
                    torch.nn.functional.softplus(d2 + head_dict.output_bias),
                )
            else:
                raise NotImplementedError

        if sample:
            for head_dict in self.cfg.heads:
                head_name = head_dict.name
                dist = outputs[head_name]

                if self.cfg.distribution_eval == "mean":
                    out = dist.mean
                elif self.cfg.distribution_eval == "mode":
                    out = dist.mode
                elif self.cfg.distribution_eval == "sample_mean":
                    out = dist.sample([10]).mean(-1)
                else:
                    # use rsample if gradient is needed
                    out = dist.rsample() if self.training else dist.sample()

                outputs[head_name] = get_activation(head_dict.output_activation)(out)
                outputs[f"{head_name}_dist"] = dist

        for head in self.cfg.heads:
            if head.shape:
                if not sample:
                    raise ValueError(
                        "Cannot reshape non-sampled probabilistic outputs"
                    )
                outputs[head.name] = outputs[head.name].reshape(*head.shape)

            if head.add_to_decoder_features:
                outputs[f"decoder_{head.name}"] = outputs[head.name]
                del outputs[head.name]

        return outputs
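Each head predicts the two parameters of a distribution rather than a point estimate; with the default `beta` setting, softplus keeps both concentration parameters positive and the mode is taken at evaluation time (`distribution_eval="mode"`). A minimal sketch of that evaluation path, with illustrative raw head outputs:

```python
import torch
import torch.nn.functional as F

# d1 and d2 stand in for the two per-head branch outputs before softplus
# (values here are illustrative, not taken from the model).
d1 = torch.tensor([2.0])
d2 = torch.tensor([1.5])

dist = torch.distributions.Beta(F.softplus(d1), F.softplus(d2))
point_estimate = dist.mode    # used when cfg.distribution_eval == "mode"
random_draw = dist.sample()   # used when distribution_eval == "sample"
print(point_estimate, random_draw)
```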
sf3d/models/isosurface.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from jaxtyping import Float, Integer
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
from .mesh import Mesh
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class IsosurfaceHelper(nn.Module):
|
| 13 |
+
points_range: Tuple[float, float] = (0, 1)
|
| 14 |
+
|
| 15 |
+
@property
|
| 16 |
+
def grid_vertices(self) -> Float[Tensor, "N 3"]:
|
| 17 |
+
raise NotImplementedError
|
| 18 |
+
|
| 19 |
+
@property
|
| 20 |
+
def requires_instance_per_batch(self) -> bool:
|
| 21 |
+
return False
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MarchingTetrahedraHelper(IsosurfaceHelper):
|
| 25 |
+
def __init__(self, resolution: int, tets_path: str):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.resolution = resolution
|
| 28 |
+
self.tets_path = tets_path
|
| 29 |
+
|
| 30 |
+
self.triangle_table: Float[Tensor, "..."]
|
| 31 |
+
self.register_buffer(
|
| 32 |
+
"triangle_table",
|
| 33 |
+
torch.as_tensor(
|
| 34 |
+
[
|
| 35 |
+
[-1, -1, -1, -1, -1, -1],
|
| 36 |
+
[1, 0, 2, -1, -1, -1],
|
| 37 |
+
[4, 0, 3, -1, -1, -1],
|
| 38 |
+
[1, 4, 2, 1, 3, 4],
|
| 39 |
+
[3, 1, 5, -1, -1, -1],
|
| 40 |
+
[2, 3, 0, 2, 5, 3],
|
| 41 |
+
[1, 4, 0, 1, 5, 4],
|
| 42 |
+
[4, 2, 5, -1, -1, -1],
|
| 43 |
+
[4, 5, 2, -1, -1, -1],
|
| 44 |
+
[4, 1, 0, 4, 5, 1],
|
| 45 |
+
[3, 2, 0, 3, 5, 2],
|
| 46 |
+
[1, 3, 5, -1, -1, -1],
|
| 47 |
+
[4, 1, 2, 4, 3, 1],
|
| 48 |
+
[3, 0, 4, -1, -1, -1],
|
| 49 |
+
[2, 0, 1, -1, -1, -1],
|
| 50 |
+
[-1, -1, -1, -1, -1, -1],
|
| 51 |
+
],
|
| 52 |
+
dtype=torch.long,
|
| 53 |
+
),
|
| 54 |
+
persistent=False,
|
| 55 |
+
)
|
| 56 |
+
self.num_triangles_table: Integer[Tensor, "..."]
|
| 57 |
+
self.register_buffer(
|
| 58 |
+
"num_triangles_table",
|
| 59 |
+
torch.as_tensor(
|
| 60 |
+
[0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
|
| 61 |
+
),
|
| 62 |
+
persistent=False,
|
| 63 |
+
)
|
| 64 |
+
self.base_tet_edges: Integer[Tensor, "..."]
|
| 65 |
+
self.register_buffer(
|
| 66 |
+
"base_tet_edges",
|
| 67 |
+
torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
|
| 68 |
+
persistent=False,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
tets = np.load(self.tets_path)
|
| 72 |
+
self._grid_vertices: Float[Tensor, "..."]
|
| 73 |
+
self.register_buffer(
|
| 74 |
+
"_grid_vertices",
|
| 75 |
+
torch.from_numpy(tets["vertices"]).float(),
|
| 76 |
+
persistent=False,
|
| 77 |
+
)
|
| 78 |
+
self.indices: Integer[Tensor, "..."]
|
| 79 |
+
self.register_buffer(
|
| 80 |
+
"indices", torch.from_numpy(tets["indices"]).long(), persistent=False
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None
|
| 84 |
+
|
| 85 |
+
center_indices, boundary_indices = self.get_center_boundary_index(
|
| 86 |
+
self._grid_vertices
|
| 87 |
+
)
|
| 88 |
+
self.center_indices: Integer[Tensor, "..."]
|
| 89 |
+
self.register_buffer("center_indices", center_indices, persistent=False)
|
| 90 |
+
self.boundary_indices: Integer[Tensor, "..."]
|
| 91 |
+
self.register_buffer("boundary_indices", boundary_indices, persistent=False)
|
| 92 |
+
|
| 93 |
+
def get_center_boundary_index(self, verts):
|
| 94 |
+
magn = torch.sum(verts**2, dim=-1)
|
| 95 |
+
|
| 96 |
+
center_idx = torch.argmin(magn)
|
| 97 |
+
boundary_neg = verts == verts.max()
|
| 98 |
+
boundary_pos = verts == verts.min()
|
| 99 |
+
|
| 100 |
+
boundary = torch.bitwise_or(boundary_pos, boundary_neg)
|
| 101 |
+
boundary = torch.sum(boundary.float(), dim=-1)
|
| 102 |
+
|
| 103 |
+
boundary_idx = torch.nonzero(boundary)
|
| 104 |
+
return center_idx, boundary_idx.squeeze(dim=-1)
|
| 105 |
+
|
| 106 |
+
def normalize_grid_deformation(
|
| 107 |
+
self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
|
| 108 |
+
) -> Float[Tensor, "Nv 3"]:
|
| 109 |
+
return (
|
| 110 |
+
(self.points_range[1] - self.points_range[0])
|
| 111 |
+
/ self.resolution # half tet size is approximately 1 / self.resolution
|
| 112 |
+
* torch.tanh(grid_vertex_offsets)
|
| 113 |
+
) # FIXME: hard-coded activation
|
| 114 |
+
|
| 115 |
+
@property
|
| 116 |
+
def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
|
| 117 |
+
return self._grid_vertices
|
| 118 |
+
|
| 119 |
+
@property
|
| 120 |
+
def all_edges(self) -> Integer[Tensor, "Ne 2"]:
|
| 121 |
+
if self._all_edges is None:
|
| 122 |
+
# compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation)
|
| 123 |
+
edges = torch.tensor(
|
| 124 |
+
[0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
|
| 125 |
+
dtype=torch.long,
|
| 126 |
+
device=self.indices.device,
|
| 127 |
+
)
|
| 128 |
+
_all_edges = self.indices[:, edges].reshape(-1, 2)
|
| 129 |
+
_all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
|
| 130 |
+
_all_edges = torch.unique(_all_edges_sorted, dim=0)
|
| 131 |
+
self._all_edges = _all_edges
|
| 132 |
+
return self._all_edges
|
    def sort_edges(self, edges_ex2):
        with torch.no_grad():
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)

            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)

        return torch.stack([a, b], -1)

    def _forward(self, pos_nx3, sdf_n, tet_fx4):
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            valid_tets = (occ_sum > 0) & (occ_sum < 4)
            occ_sum = occ_sum[valid_tets]

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

            unique_edges = unique_edges.long()
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            mapping = (
                torch.ones(
                    (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
                )
                * -1
            )
            mapping[mask_edges] = torch.arange(
                mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
            )
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        edges_to_interp_sdf[:, -1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1, 6)

        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        faces = torch.cat(
            (
                torch.gather(
                    input=idx_map[num_triangles == 1],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
                ).reshape(-1, 3),
                torch.gather(
                    input=idx_map[num_triangles == 2],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
                ).reshape(-1, 3),
            ),
            dim=0,
        )

        return verts, faces

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        if deformation is not None:
            grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
                deformation
            )
        else:
            grid_vertices = self.grid_vertices

        v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)

        mesh = Mesh(
            v_pos=v_pos,
            t_pos_idx=t_pos_idx,
            # extras
            grid_vertices=grid_vertices,
            tet_edges=self.all_edges,
            grid_level=level,
            grid_deformation=deformation,
        )

        return mesh
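For orientation, a minimal usage sketch of the marching-tetrahedra forward pass above. The helper class name, its constructor signature, and the tet-grid path are assumptions inferred from the rest of this file and the bundled load/tets/160_tets.npz asset, not verified here:

import torch

from sf3d.models.isosurface import MarchingTetrahedraHelper  # assumed class name

# Assumed constructor: a grid resolution plus the packaged tetrahedra file.
iso = MarchingTetrahedraHelper(160, "load/tets/160_tets.npz")

# One signed-distance value per grid vertex; the sign convention follows
# `occ_n = sdf_n > 0` above (positive counts as inside).
level = torch.randn(iso.grid_vertices.shape[0], 1)

mesh = iso(level)  # optionally iso(level, deformation) with a per-vertex offset
print(mesh.v_pos.shape, mesh.t_pos_idx.shape)

Only tetrahedra whose four corners straddle the zero level set contribute triangles, so the number of output vertices and faces varies with the input field.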
sf3d/models/mesh.py
ADDED
|
@@ -0,0 +1,289 @@
from __future__ import annotations

import math
from typing import Any, Dict, Optional

import gpytoolbox
import numpy as np
import pynanoinstantmeshes
import torch
import torch.nn.functional as F
import trimesh
from jaxtyping import Float, Integer
from torch import Tensor

from sf3d.models.utils import dot

try:
    from uv_unwrapper import Unwrapper
except ImportError:
    import logging

    logging.warning(
        "Could not import uv_unwrapper. Please install it via `pip install uv_unwrapper/`"
    )
    # Exit early to avoid further errors
    raise ImportError("uv_unwrapper not found")


class Mesh:
    def __init__(
        self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
    ) -> None:
        self.v_pos: Float[Tensor, "Nv 3"] = v_pos
        self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
        self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
        self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
        self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
        self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
        self.extras: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.add_extra(k, v)

        self.unwrapper = Unwrapper()

    def add_extra(self, k, v) -> None:
        self.extras[k] = v

    @property
    def requires_grad(self):
        return self.v_pos.requires_grad

    @property
    def v_nrm(self):
        if self._v_nrm is None:
            self._v_nrm = self._compute_vertex_normal()
        return self._v_nrm

    @property
    def v_tng(self):
        if self._v_tng is None:
            self._v_tng = self._compute_vertex_tangent()
        return self._v_tng

    @property
    def v_tex(self):
        if self._v_tex is None:
            self.unwrap_uv()
        return self._v_tex

    @property
    def edges(self):
        if self._edges is None:
            self._edges = self._compute_edges()
        return self._edges

    def _compute_vertex_normal(self):
        i0 = self.t_pos_idx[:, 0]
        i1 = self.t_pos_idx[:, 1]
        i2 = self.t_pos_idx[:, 2]

        v0 = self.v_pos[i0, :]
        v1 = self.v_pos[i1, :]
        v2 = self.v_pos[i2, :]

        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

        # Splat face normals to vertices
        v_nrm = torch.zeros_like(self.v_pos)
        v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

        # Normalize, replace zero (degenerated) normals with some default value
        v_nrm = torch.where(
            dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
        )
        v_nrm = F.normalize(v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(v_nrm))

        return v_nrm

    def _compute_vertex_tangent(self):
        vn_idx = [None] * 3
        pos = [None] * 3
        tex = [None] * 3
        for i in range(0, 3):
            pos[i] = self.v_pos[self.t_pos_idx[:, i]]
            tex[i] = self.v_tex[self.t_pos_idx[:, i]]
            # t_nrm_idx is always the same as t_pos_idx
            vn_idx[i] = self.t_pos_idx[:, i]

        tangents = torch.zeros_like(self.v_nrm)
        tansum = torch.zeros_like(self.v_nrm)

        # Compute tangent space for each triangle
        duv1 = tex[1] - tex[0]
        duv2 = tex[2] - tex[0]
        dpos1 = pos[1] - pos[0]
        dpos2 = pos[2] - pos[0]

        tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]

        denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]

        # Avoid division by zero for degenerated texture coordinates
        denom_safe = denom.clip(1e-6)
        tang = tng_nom / denom_safe

        # Update all 3 vertices
        for i in range(0, 3):
            idx = vn_idx[i][:, None].repeat(1, 3)
            tangents.scatter_add_(0, idx, tang)  # tangents[n_i] = tangents[n_i] + tang
            tansum.scatter_add_(
                0, idx, torch.ones_like(tang)
            )  # tansum[n_i] = tansum[n_i] + 1
        # Also normalize it. Here we do not normalize the individual triangles first so larger area
        # triangles influence the tangent space more
        tangents = tangents / tansum

        # Normalize and make sure tangent is perpendicular to normal
        tangents = F.normalize(tangents, dim=1)
        tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(tangents))

        return tangents

    def quad_remesh(
        self,
        quad_vertex_count: int = -1,
        quad_rosy: int = 4,
        quad_crease_angle: float = -1.0,
        quad_smooth_iter: int = 2,
        quad_align_to_boundaries: bool = False,
    ) -> Mesh:
        if quad_vertex_count < 0:
            quad_vertex_count = self.v_pos.shape[0]
        v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
        t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.uint32)

        new_vert, new_faces = pynanoinstantmeshes.remesh(
            v_pos,
            t_pos_idx,
            quad_vertex_count // 4,
            rosy=quad_rosy,
            posy=4,
            creaseAngle=quad_crease_angle,
            align_to_boundaries=quad_align_to_boundaries,
            smooth_iter=quad_smooth_iter,
            deterministic=False,
        )

        # Briefly load in trimesh
        mesh = trimesh.Trimesh(vertices=new_vert, faces=new_faces.astype(np.int32))

        v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos).contiguous()
        t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx).contiguous()

        # Create new mesh
        return Mesh(v_pos, t_pos_idx)

    def triangle_remesh(
        self,
        triangle_average_edge_length_multiplier: Optional[float] = None,
        triangle_remesh_steps: int = 10,
        triangle_vertex_count=-1,
    ):
        if triangle_vertex_count > 0:
            reduction = triangle_vertex_count / self.v_pos.shape[0]
            print("Triangle reduction:", reduction)
            v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
            t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)
            if reduction > 1.0:
                subdivide_iters = int(math.ceil(math.log(reduction) / math.log(2)))
                print("Subdivide iters:", subdivide_iters)
                v_pos, t_pos_idx = gpytoolbox.subdivide(
                    v_pos,
                    t_pos_idx,
                    iters=subdivide_iters,
                )
                reduction = triangle_vertex_count / v_pos.shape[0]

            # Simplify
            points_out, faces_out, _, _ = gpytoolbox.decimate(
                v_pos,
                t_pos_idx,
                face_ratio=reduction,
            )

            # Convert back to torch
            self.v_pos = torch.from_numpy(points_out).to(self.v_pos)
            self.t_pos_idx = torch.from_numpy(faces_out).to(self.t_pos_idx)
            self._edges = None
            triangle_average_edge_length_multiplier = None

        edges = self.edges
        if triangle_average_edge_length_multiplier is None:
            h = None
        else:
            h = float(
                torch.linalg.norm(
                    self.v_pos[edges[:, 0]] - self.v_pos[edges[:, 1]], dim=1
                )
                .mean()
                .item()
                * triangle_average_edge_length_multiplier
            )

        # Convert to numpy
        v_pos = self.v_pos.detach().cpu().numpy().astype(np.float64)
        t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)

        # Remesh
        v_remesh, f_remesh = gpytoolbox.remesh_botsch(
            v_pos,
            t_pos_idx,
            triangle_remesh_steps,
            h,
        )

        # Convert back to torch
        v_pos = torch.from_numpy(v_remesh).to(self.v_pos).contiguous()
        t_pos_idx = torch.from_numpy(f_remesh).to(self.t_pos_idx).contiguous()

        # Create new mesh
        return Mesh(v_pos, t_pos_idx)

    @torch.no_grad()
    def unwrap_uv(
        self,
        island_padding: float = 0.02,
    ) -> Mesh:
        uv, indices = self.unwrapper(
            self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
        )

        # Do store per vertex UVs.
        # This means we need to duplicate some vertices at the seams
        individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
        individual_faces = torch.arange(
            individual_vertices.shape[0],
            device=individual_vertices.device,
            dtype=self.t_pos_idx.dtype,
        ).reshape(-1, 3)
        uv_flat = uv[indices].reshape((-1, 2))
        # uv_flat[:, 1] = 1 - uv_flat[:, 1]

        self.v_pos = individual_vertices
        self.t_pos_idx = individual_faces
        self._v_tex = uv_flat
        self._v_nrm = self._compute_vertex_normal()
        self._v_tng = self._compute_vertex_tangent()

    def _compute_edges(self):
        # Compute edges
        edges = torch.cat(
            [
                self.t_pos_idx[:, [0, 1]],
                self.t_pos_idx[:, [1, 2]],
                self.t_pos_idx[:, [2, 0]],
            ],
            dim=0,
        )
        edges = edges.sort()[0]
        edges = torch.unique(edges, dim=0)
        return edges
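A minimal round trip through the Mesh wrapper above, assuming the uv_unwrapper extension is installed (the box primitive is only an illustrative input, not part of this repo):

import torch
import trimesh

from sf3d.models.mesh import Mesh

box = trimesh.creation.box()
mesh = Mesh(
    torch.as_tensor(box.vertices, dtype=torch.float32),
    torch.as_tensor(box.faces, dtype=torch.int64),
)

mesh = mesh.quad_remesh(quad_vertex_count=1000)  # instant-meshes based quad remesh
mesh.unwrap_uv(island_padding=0.02)  # fills v_tex and recomputes v_nrm / v_tng in place
print(mesh.v_tex.shape, mesh.v_nrm.shape, mesh.v_tng.shape)

Note that unwrap_uv mutates the mesh: vertices are duplicated along UV seams so that texture coordinates can be stored per vertex.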
sf3d/models/network.py
ADDED
|
@@ -0,0 +1,213 @@
from dataclasses import dataclass, field
from typing import Callable, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Float
from torch import Tensor
from torch.amp import custom_bwd, custom_fwd
from torch.autograd import Function

from sf3d.models.utils import BaseModule, normalize
from sf3d.utils import get_device


def conditional_decorator(decorator_with_args, condition, *args, **kwargs):
    def wrapper(fn):
        if condition:
            if len(kwargs) == 0:
                return decorator_with_args
            return decorator_with_args(*args, **kwargs)(fn)
        else:
            return fn

    return wrapper


class PixelShuffleUpsampleNetwork(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 1024
        out_channels: int = 40
        scale_factor: int = 4

        conv_layers: int = 4
        conv_kernel_size: int = 3

    cfg: Config

    def configure(self) -> None:
        layers = []
        output_channels = self.cfg.out_channels * self.cfg.scale_factor**2

        in_channels = self.cfg.in_channels
        for i in range(self.cfg.conv_layers):
            cur_out_channels = (
                in_channels if i != self.cfg.conv_layers - 1 else output_channels
            )
            layers.append(
                nn.Conv2d(
                    in_channels,
                    cur_out_channels,
                    self.cfg.conv_kernel_size,
                    padding=(self.cfg.conv_kernel_size - 1) // 2,
                )
            )
            if i != self.cfg.conv_layers - 1:
                layers.append(nn.ReLU(inplace=True))

        layers.append(nn.PixelShuffle(self.cfg.scale_factor))

        self.upsample = nn.Sequential(*layers)

    def forward(
        self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
    ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
        return rearrange(
            self.upsample(
                rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
            ),
            "(B Np) Co Hp Wp -> B Np Co Hp Wp",
            Np=3,
        )


class _TruncExp(Function):  # pylint: disable=abstract-method
    # Implementation from torch-ngp:
    # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
    @staticmethod
    @conditional_decorator(
        custom_fwd,
        "cuda" in get_device(),
        cast_inputs=torch.float32,
        device_type="cuda",
    )
    def forward(ctx, x):  # pylint: disable=arguments-differ
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @conditional_decorator(custom_bwd, "cuda" in get_device())
    def backward(ctx, g):  # pylint: disable=arguments-differ
        x = ctx.saved_tensors[0]
        return g * torch.exp(torch.clamp(x, max=15))


trunc_exp = _TruncExp.apply


def get_activation(name) -> Callable:
    if name is None:
        return lambda x: x
    name = name.lower()
    if name == "none" or name == "linear" or name == "identity":
        return lambda x: x
    elif name == "lin2srgb":
        return lambda x: torch.where(
            x > 0.0031308,
            torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
            12.92 * x,
        ).clamp(0.0, 1.0)
    elif name == "exp":
        return lambda x: torch.exp(x)
    elif name == "shifted_exp":
        return lambda x: torch.exp(x - 1.0)
    elif name == "trunc_exp":
        return trunc_exp
    elif name == "shifted_trunc_exp":
        return lambda x: trunc_exp(x - 1.0)
    elif name == "sigmoid":
        return lambda x: torch.sigmoid(x)
    elif name == "tanh":
        return lambda x: torch.tanh(x)
    elif name == "shifted_softplus":
        return lambda x: F.softplus(x - 1.0)
    elif name == "scale_-11_01":
        return lambda x: x * 0.5 + 0.5
    elif name == "negative":
        return lambda x: -x
    elif name == "normalize_channel_last":
        return lambda x: normalize(x)
    elif name == "normalize_channel_first":
        return lambda x: normalize(x, dim=1)
    else:
        try:
            return getattr(F, name)
        except AttributeError:
            raise ValueError(f"Unknown activation function: {name}")


@dataclass
class HeadSpec:
    name: str
    out_channels: int
    n_hidden_layers: int
    output_activation: Optional[str] = None
    out_bias: float = 0.0


class MaterialMLP(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 120
        n_neurons: int = 64
        activation: str = "silu"
        heads: List[HeadSpec] = field(default_factory=lambda: [])

    cfg: Config

    def configure(self) -> None:
        assert len(self.cfg.heads) > 0
        heads = {}
        for head in self.cfg.heads:
            head_layers = []
            for i in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(
                        self.cfg.in_channels if i == 0 else self.cfg.n_neurons,
                        self.cfg.n_neurons,
                    ),
                    self.make_activation(self.cfg.activation),
                ]
            head_layers += [
                nn.Linear(
                    self.cfg.n_neurons,
                    head.out_channels,
                ),
            ]
            heads[head.name] = nn.Sequential(*head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def keys(self):
        return self.heads.keys()

    def forward(
        self, x, include: Optional[List] = None, exclude: Optional[List] = None
    ):
        if include is not None and exclude is not None:
            raise ValueError("Cannot specify both include and exclude.")
        if include is not None:
            heads = [h for h in self.cfg.heads if h.name in include]
        elif exclude is not None:
            heads = [h for h in self.cfg.heads if h.name not in exclude]
        else:
            heads = self.cfg.heads

        out = {
            head.name: get_activation(head.output_activation)(
                self.heads[head.name](x) + head.out_bias
            )
            for head in heads
        }

        return out
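A sketch of wiring up MaterialMLP with two heads. It assumes BaseModule accepts a plain dict that is parsed into the nested Config dataclass (common in threestudio-style codebases, but not shown in this diff); the head names and shapes are illustrative only:

import torch

from sf3d.models.network import MaterialMLP

mlp = MaterialMLP(
    {
        "in_channels": 120,
        "n_neurons": 64,
        "activation": "silu",
        "heads": [
            # Parsed into HeadSpec entries; names are hypothetical.
            {"name": "albedo", "out_channels": 3, "n_hidden_layers": 2, "output_activation": "sigmoid"},
            {"name": "roughness", "out_channels": 1, "n_hidden_layers": 1, "output_activation": "sigmoid"},
        ],
    }
)

features = torch.randn(4096, 120)        # per-surface-point feature vectors
out = mlp(features, include=["albedo"])  # run only the selected head
print(out["albedo"].shape)               # -> torch.Size([4096, 3])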
sf3d/models/tokenizers/dinov2.py
ADDED
|
@@ -0,0 +1,1196 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch DINOv2 model."""
|
| 16 |
+
|
| 17 |
+
import collections.abc
|
| 18 |
+
import math
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import Dict, List, Optional, Set, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
import torch.utils.checkpoint
|
| 25 |
+
from torch import nn
|
| 26 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 27 |
+
from transformers.activations import ACT2FN
|
| 28 |
+
from transformers.modeling_outputs import (
|
| 29 |
+
BackboneOutput,
|
| 30 |
+
BaseModelOutput,
|
| 31 |
+
BaseModelOutputWithPooling,
|
| 32 |
+
ImageClassifierOutput,
|
| 33 |
+
)
|
| 34 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 35 |
+
from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
|
| 36 |
+
from transformers.pytorch_utils import (
|
| 37 |
+
find_pruneable_heads_and_indices,
|
| 38 |
+
prune_linear_layer,
|
| 39 |
+
)
|
| 40 |
+
from transformers.utils import (
|
| 41 |
+
add_code_sample_docstrings,
|
| 42 |
+
add_start_docstrings,
|
| 43 |
+
add_start_docstrings_to_model_forward,
|
| 44 |
+
logging,
|
| 45 |
+
replace_return_docstrings,
|
| 46 |
+
)
|
| 47 |
+
from transformers.utils.backbone_utils import BackboneMixin
|
| 48 |
+
|
| 49 |
+
logger = logging.get_logger(__name__)
|
| 50 |
+
|
| 51 |
+
# General docstring
|
| 52 |
+
_CONFIG_FOR_DOC = "Dinov2Config"
|
| 53 |
+
|
| 54 |
+
# Base docstring
|
| 55 |
+
_CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
|
| 56 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
|
| 57 |
+
|
| 58 |
+
# Image classification docstring
|
| 59 |
+
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
| 63 |
+
"facebook/dinov2-base",
|
| 64 |
+
# See all DINOv2 models at https://huggingface.co/models?filter=dinov2
|
| 65 |
+
]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Dinov2Embeddings(nn.Module):
|
| 69 |
+
"""
|
| 70 |
+
Construct the CLS token, mask token, position and patch embeddings.
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def __init__(self, config: Dinov2Config) -> None:
|
| 74 |
+
super().__init__()
|
| 75 |
+
|
| 76 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
| 77 |
+
# register as mask token as it's not used in optimization
|
| 78 |
+
# to avoid the use of find_unused_parameters_true
|
| 79 |
+
# self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
|
| 80 |
+
self.register_buffer("mask_token", torch.zeros(1, config.hidden_size))
|
| 81 |
+
self.patch_embeddings = Dinov2PatchEmbeddings(config)
|
| 82 |
+
num_patches = self.patch_embeddings.num_patches
|
| 83 |
+
self.position_embeddings = nn.Parameter(
|
| 84 |
+
torch.randn(1, num_patches + 1, config.hidden_size)
|
| 85 |
+
)
|
| 86 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 87 |
+
self.config = config
|
| 88 |
+
|
| 89 |
+
def interpolate_pos_encoding(
|
| 90 |
+
self, embeddings: torch.Tensor, height: int, width: int
|
| 91 |
+
) -> torch.Tensor:
|
| 92 |
+
"""
|
| 93 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
| 94 |
+
resolution images.
|
| 95 |
+
|
| 96 |
+
Source:
|
| 97 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
num_patches = embeddings.shape[1] - 1
|
| 101 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
| 102 |
+
if num_patches == num_positions and height == width:
|
| 103 |
+
return self.position_embeddings
|
| 104 |
+
class_pos_embed = self.position_embeddings[:, 0]
|
| 105 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
| 106 |
+
dim = embeddings.shape[-1]
|
| 107 |
+
height = height // self.config.patch_size
|
| 108 |
+
width = width // self.config.patch_size
|
| 109 |
+
# we add a small number to avoid floating point error in the interpolation
|
| 110 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
| 111 |
+
height, width = height + 0.1, width + 0.1
|
| 112 |
+
patch_pos_embed = patch_pos_embed.reshape(
|
| 113 |
+
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
|
| 114 |
+
)
|
| 115 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
| 116 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 117 |
+
patch_pos_embed,
|
| 118 |
+
scale_factor=(
|
| 119 |
+
height / math.sqrt(num_positions),
|
| 120 |
+
width / math.sqrt(num_positions),
|
| 121 |
+
),
|
| 122 |
+
mode="bicubic",
|
| 123 |
+
align_corners=False,
|
| 124 |
+
)
|
| 125 |
+
if (
|
| 126 |
+
int(height) != patch_pos_embed.shape[-2]
|
| 127 |
+
or int(width) != patch_pos_embed.shape[-1]
|
| 128 |
+
):
|
| 129 |
+
raise ValueError(
|
| 130 |
+
"Width or height does not match with the interpolated position embeddings"
|
| 131 |
+
)
|
| 132 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 133 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
| 134 |
+
|
| 135 |
+
def forward(
|
| 136 |
+
self,
|
| 137 |
+
pixel_values: torch.Tensor,
|
| 138 |
+
bool_masked_pos: Optional[torch.Tensor] = None,
|
| 139 |
+
) -> torch.Tensor:
|
| 140 |
+
batch_size, _, height, width = pixel_values.shape
|
| 141 |
+
patch_embeddings = self.patch_embeddings(pixel_values)
|
| 142 |
+
embeddings = patch_embeddings
|
| 143 |
+
|
| 144 |
+
if bool_masked_pos is not None:
|
| 145 |
+
embeddings = torch.where(
|
| 146 |
+
bool_masked_pos.unsqueeze(-1),
|
| 147 |
+
self.mask_token.to(embeddings.dtype).unsqueeze(0),
|
| 148 |
+
embeddings,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# add the [CLS] token to the embedded patch tokens
|
| 152 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
| 153 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
| 154 |
+
|
| 155 |
+
# add positional encoding to each token
|
| 156 |
+
embeddings = embeddings + self.interpolate_pos_encoding(
|
| 157 |
+
embeddings, height, width
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
embeddings = self.dropout(embeddings)
|
| 161 |
+
|
| 162 |
+
return embeddings
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class Dinov2PatchEmbeddings(nn.Module):
|
| 166 |
+
"""
|
| 167 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
| 168 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
| 169 |
+
Transformer.
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
def __init__(self, config):
|
| 173 |
+
super().__init__()
|
| 174 |
+
image_size, patch_size = config.image_size, config.patch_size
|
| 175 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
| 176 |
+
|
| 177 |
+
image_size = (
|
| 178 |
+
image_size
|
| 179 |
+
if isinstance(image_size, collections.abc.Iterable)
|
| 180 |
+
else (image_size, image_size)
|
| 181 |
+
)
|
| 182 |
+
patch_size = (
|
| 183 |
+
patch_size
|
| 184 |
+
if isinstance(patch_size, collections.abc.Iterable)
|
| 185 |
+
else (patch_size, patch_size)
|
| 186 |
+
)
|
| 187 |
+
num_patches = (image_size[1] // patch_size[1]) * (
|
| 188 |
+
image_size[0] // patch_size[0]
|
| 189 |
+
)
|
| 190 |
+
self.image_size = image_size
|
| 191 |
+
self.patch_size = patch_size
|
| 192 |
+
self.num_channels = num_channels
|
| 193 |
+
self.num_patches = num_patches
|
| 194 |
+
|
| 195 |
+
self.projection = nn.Conv2d(
|
| 196 |
+
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
| 200 |
+
"""
|
| 201 |
+
num_channels = pixel_values.shape[1]
|
| 202 |
+
if num_channels != self.num_channels:
|
| 203 |
+
raise ValueError(
|
| 204 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
| 205 |
+
f" Expected {self.num_channels} but got {num_channels}."
|
| 206 |
+
)
|
| 207 |
+
"""
|
| 208 |
+
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
| 209 |
+
return embeddings
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
|
| 213 |
+
class Dinov2SelfAttention(nn.Module):
|
| 214 |
+
def __init__(self, config: Dinov2Config) -> None:
|
| 215 |
+
super().__init__()
|
| 216 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
| 217 |
+
config, "embedding_size"
|
| 218 |
+
):
|
| 219 |
+
raise ValueError(
|
| 220 |
+
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
|
| 221 |
+
f"heads {config.num_attention_heads}."
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
self.num_attention_heads = config.num_attention_heads
|
| 225 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 226 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 227 |
+
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
| 228 |
+
|
| 229 |
+
self.query = nn.Linear(
|
| 230 |
+
config.hidden_size, self.all_head_size, bias=config.qkv_bias
|
| 231 |
+
)
|
| 232 |
+
self.key = nn.Linear(
|
| 233 |
+
config.hidden_size, self.all_head_size, bias=config.qkv_bias
|
| 234 |
+
)
|
| 235 |
+
self.value = nn.Linear(
|
| 236 |
+
config.hidden_size, self.all_head_size, bias=config.qkv_bias
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 240 |
+
|
| 241 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
| 242 |
+
new_x_shape = x.size()[:-1] + (
|
| 243 |
+
self.num_attention_heads,
|
| 244 |
+
self.attention_head_size,
|
| 245 |
+
)
|
| 246 |
+
x = x.view(new_x_shape)
|
| 247 |
+
return x.permute(0, 2, 1, 3)
|
| 248 |
+
|
| 249 |
+
def forward(
|
| 250 |
+
self,
|
| 251 |
+
hidden_states,
|
| 252 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 253 |
+
output_attentions: bool = False,
|
| 254 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
| 255 |
+
mixed_query_layer = self.query(hidden_states)
|
| 256 |
+
|
| 257 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
| 258 |
+
assert head_mask is None and not output_attentions
|
| 259 |
+
new_size = hidden_states.size()[:-1] + (
|
| 260 |
+
self.num_attention_heads,
|
| 261 |
+
self.attention_head_size,
|
| 262 |
+
)
|
| 263 |
+
key_layer = self.key(hidden_states).reshape(new_size).transpose(1, 2)
|
| 264 |
+
value_layer = self.value(hidden_states).reshape(new_size).transpose(1, 2)
|
| 265 |
+
query_layer = mixed_query_layer.reshape(new_size).transpose(1, 2)
|
| 266 |
+
context_layer = F.scaled_dot_product_attention(
|
| 267 |
+
query_layer,
|
| 268 |
+
key_layer,
|
| 269 |
+
value_layer,
|
| 270 |
+
dropout_p=self.attention_probs_dropout_prob,
|
| 271 |
+
is_causal=False,
|
| 272 |
+
)
|
| 273 |
+
context_layer = context_layer.transpose(1, 2).reshape(
|
| 274 |
+
*hidden_states.size()[:-1], -1
|
| 275 |
+
)
|
| 276 |
+
else:
|
| 277 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 278 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 279 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
| 280 |
+
|
| 281 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 282 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 283 |
+
|
| 284 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 285 |
+
|
| 286 |
+
# Normalize the attention scores to probabilities.
|
| 287 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
| 288 |
+
|
| 289 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 290 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 291 |
+
attention_probs = self.dropout(attention_probs)
|
| 292 |
+
|
| 293 |
+
# Mask heads if we want to
|
| 294 |
+
if head_mask is not None:
|
| 295 |
+
attention_probs = attention_probs * head_mask
|
| 296 |
+
|
| 297 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
| 298 |
+
|
| 299 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 300 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 301 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
| 302 |
+
|
| 303 |
+
outputs = (
|
| 304 |
+
(context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
return outputs
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
|
| 311 |
+
class Dinov2SelfOutput(nn.Module):
|
| 312 |
+
"""
|
| 313 |
+
The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
|
| 314 |
+
layernorm applied before each block.
|
| 315 |
+
"""
|
| 316 |
+
|
| 317 |
+
def __init__(self, config: Dinov2Config) -> None:
|
| 318 |
+
super().__init__()
|
| 319 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 320 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 321 |
+
|
| 322 |
+
def forward(
|
| 323 |
+
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
|
| 324 |
+
) -> torch.Tensor:
|
| 325 |
+
hidden_states = self.dense(hidden_states)
|
| 326 |
+
hidden_states = self.dropout(hidden_states)
|
| 327 |
+
|
| 328 |
+
return hidden_states
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
|
| 332 |
+
class Dinov2Attention(nn.Module):
|
| 333 |
+
def __init__(self, config: Dinov2Config) -> None:
|
| 334 |
+
super().__init__()
|
| 335 |
+
self.attention = Dinov2SelfAttention(config)
|
| 336 |
+
self.output = Dinov2SelfOutput(config)
|
| 337 |
+
self.pruned_heads = set()
|
| 338 |
+
|
| 339 |
+
def prune_heads(self, heads: Set[int]) -> None:
|
| 340 |
+
if len(heads) == 0:
|
| 341 |
+
return
|
| 342 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 343 |
+
heads,
|
| 344 |
+
self.attention.num_attention_heads,
|
| 345 |
+
self.attention.attention_head_size,
|
| 346 |
+
self.pruned_heads,
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
# Prune linear layers
|
| 350 |
+
self.attention.query = prune_linear_layer(self.attention.query, index)
|
| 351 |
+
self.attention.key = prune_linear_layer(self.attention.key, index)
|
| 352 |
+
self.attention.value = prune_linear_layer(self.attention.value, index)
|
| 353 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 354 |
+
|
| 355 |
+
# Update hyper params and store pruned heads
|
| 356 |
+
self.attention.num_attention_heads = self.attention.num_attention_heads - len(
|
| 357 |
+
heads
|
| 358 |
+
)
|
| 359 |
+
self.attention.all_head_size = (
|
| 360 |
+
self.attention.attention_head_size * self.attention.num_attention_heads
|
| 361 |
+
)
|
| 362 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 363 |
+
|
| 364 |
+
def forward(
|
| 365 |
+
self,
|
| 366 |
+
hidden_states: torch.Tensor,
|
| 367 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 368 |
+
output_attentions: bool = False,
|
| 369 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
| 370 |
+
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
|
| 371 |
+
|
| 372 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 373 |
+
|
| 374 |
+
outputs = (attention_output,) + self_outputs[
|
| 375 |
+
1:
|
| 376 |
+
] # add attentions if we output them
|
| 377 |
+
return outputs
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class Dinov2LayerScale(nn.Module):
|
| 381 |
+
def __init__(self, config) -> None:
|
| 382 |
+
super().__init__()
|
| 383 |
+
self.lambda1 = nn.Parameter(
|
| 384 |
+
config.layerscale_value * torch.ones(config.hidden_size)
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
| 388 |
+
return hidden_state * self.lambda1
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
# Copied from transformers.models.beit.modeling_beit.drop_path
|
| 392 |
+
def drop_path(
|
| 393 |
+
input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
|
| 394 |
+
) -> torch.Tensor:
|
| 395 |
+
"""
|
| 396 |
+
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 397 |
+
|
| 398 |
+
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
| 399 |
+
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 400 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
| 401 |
+
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
| 402 |
+
argument.
|
| 403 |
+
"""
|
| 404 |
+
if drop_prob == 0.0 or not training:
|
| 405 |
+
return input
|
| 406 |
+
keep_prob = 1 - drop_prob
|
| 407 |
+
shape = (input.shape[0],) + (1,) * (
|
| 408 |
+
input.ndim - 1
|
| 409 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
| 410 |
+
random_tensor = keep_prob + torch.rand(
|
| 411 |
+
shape, dtype=input.dtype, device=input.device
|
| 412 |
+
)
|
| 413 |
+
random_tensor.floor_() # binarize
|
| 414 |
+
output = input.div(keep_prob) * random_tensor
|
| 415 |
+
return output
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
# Copied from transformers.models.beit.modeling_beit.BeitDropPath
|
| 419 |
+
class Dinov2DropPath(nn.Module):
|
| 420 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 421 |
+
|
| 422 |
+
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
| 423 |
+
super().__init__()
|
| 424 |
+
self.drop_prob = drop_prob
|
| 425 |
+
|
| 426 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 427 |
+
return drop_path(hidden_states, self.drop_prob, self.training)
|
| 428 |
+
|
| 429 |
+
def extra_repr(self) -> str:
|
| 430 |
+
return "p={}".format(self.drop_prob)
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
class Dinov2MLP(nn.Module):
|
| 434 |
+
def __init__(self, config) -> None:
|
| 435 |
+
super().__init__()
|
| 436 |
+
in_features = out_features = config.hidden_size
|
| 437 |
+
hidden_features = int(config.hidden_size * config.mlp_ratio)
|
| 438 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
|
| 439 |
+
if isinstance(config.hidden_act, str):
|
| 440 |
+
self.activation = ACT2FN[config.hidden_act]
|
| 441 |
+
else:
|
| 442 |
+
self.activation = config.hidden_act
|
| 443 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
|
| 444 |
+
|
| 445 |
+
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
| 446 |
+
hidden_state = self.fc1(hidden_state)
|
| 447 |
+
hidden_state = self.activation(hidden_state)
|
| 448 |
+
hidden_state = self.fc2(hidden_state)
|
| 449 |
+
return hidden_state
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
class Dinov2SwiGLUFFN(nn.Module):
|
| 453 |
+
def __init__(self, config) -> None:
|
| 454 |
+
super().__init__()
|
| 455 |
+
in_features = out_features = config.hidden_size
|
| 456 |
+
hidden_features = int(config.hidden_size * config.mlp_ratio)
|
| 457 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
| 458 |
+
|
| 459 |
+
self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
|
| 460 |
+
self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
|
| 461 |
+
|
| 462 |
+
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
| 463 |
+
hidden_state = self.weights_in(hidden_state)
|
| 464 |
+
x1, x2 = hidden_state.chunk(2, dim=-1)
|
| 465 |
+
hidden = nn.functional.silu(x1) * x2
|
| 466 |
+
return self.weights_out(hidden)
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
class Dinov2Layer(nn.Module):
|
| 470 |
+
"""This corresponds to the Block class in the original implementation."""
|
| 471 |
+
|
| 472 |
+
def __init__(self, config: Dinov2Config) -> None:
|
| 473 |
+
super().__init__()
|
| 474 |
+
|
| 475 |
+
self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 476 |
+
self.norm1_modulation = None
|
| 477 |
+
self.attention = Dinov2Attention(config)
|
| 478 |
+
self.layer_scale1 = Dinov2LayerScale(config)
|
| 479 |
+
self.drop_path1 = (
|
| 480 |
+
Dinov2DropPath(config.drop_path_rate)
|
| 481 |
+
if config.drop_path_rate > 0.0
|
| 482 |
+
else nn.Identity()
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 486 |
+
self.norm2_modulation = None
|
| 487 |
+
|
| 488 |
+
if config.use_swiglu_ffn:
|
| 489 |
+
self.mlp = Dinov2SwiGLUFFN(config)
|
| 490 |
+
else:
|
| 491 |
+
self.mlp = Dinov2MLP(config)
|
| 492 |
+
self.layer_scale2 = Dinov2LayerScale(config)
|
| 493 |
+
self.drop_path2 = (
|
| 494 |
+
Dinov2DropPath(config.drop_path_rate)
|
| 495 |
+
if config.drop_path_rate > 0.0
|
| 496 |
+
else nn.Identity()
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
def forward(
|
| 500 |
+
self,
|
| 501 |
+
hidden_states: torch.Tensor,
|
| 502 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 503 |
+
modulation_cond: Optional[torch.Tensor] = None,
|
| 504 |
+
output_attentions: bool = False,
|
| 505 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
| 506 |
+
hidden_states_norm = self.norm1(hidden_states)
|
| 507 |
+
if self.norm1_modulation is not None:
|
| 508 |
+
assert modulation_cond is not None
|
| 509 |
+
hidden_states_norm = self.norm1_modulation(
|
| 510 |
+
hidden_states_norm, modulation_cond
|
| 511 |
+
)
|
| 512 |
+
self_attention_outputs = self.attention(
|
| 513 |
+
hidden_states_norm, # in Dinov2, layernorm is applied before self-attention
|
| 514 |
+
head_mask,
|
| 515 |
+
output_attentions=output_attentions,
|
| 516 |
+
)
|
| 517 |
+
attention_output = self_attention_outputs[0]
|
| 518 |
+
|
| 519 |
+
attention_output = self.layer_scale1(attention_output)
|
| 520 |
+
outputs = self_attention_outputs[
|
| 521 |
+
1:
|
| 522 |
+
] # add self attentions if we output attention weights
|
| 523 |
+
|
| 524 |
+
# first residual connection
|
| 525 |
+
hidden_states = attention_output + hidden_states
|
| 526 |
+
|
| 527 |
+
# in Dinov2, layernorm is also applied after self-attention
|
| 528 |
+
layer_output = self.norm2(hidden_states)
|
| 529 |
+
if self.norm2_modulation is not None:
|
| 530 |
+
assert modulation_cond is not None
|
| 531 |
+
layer_output = self.norm2_modulation(layer_output, modulation_cond)
|
| 532 |
+
layer_output = self.mlp(layer_output)
|
| 533 |
+
layer_output = self.layer_scale2(layer_output)
|
| 534 |
+
|
| 535 |
+
# second residual connection
|
| 536 |
+
layer_output = layer_output + hidden_states
|
| 537 |
+
|
| 538 |
+
outputs = (layer_output,) + outputs
|
| 539 |
+
|
| 540 |
+
return outputs
|
| 541 |
+
|
| 542 |
+
def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
|
| 543 |
+
self.norm1_modulation = norm1_mod
|
| 544 |
+
self.norm2_modulation = norm2_mod
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
|
| 548 |
+
class Dinov2Encoder(nn.Module):
|
| 549 |
+
def __init__(self, config: Dinov2Config) -> None:
|
| 550 |
+
super().__init__()
|
| 551 |
+
self.config = config
|
| 552 |
+
self.layer = nn.ModuleList(
|
| 553 |
+
[Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
|
| 554 |
+
)
|
| 555 |
+
self.gradient_checkpointing = False
|
| 556 |
+
|
| 557 |
+
def forward(
|
| 558 |
+
self,
|
| 559 |
+
hidden_states: torch.Tensor,
|
| 560 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 561 |
+
modulation_cond: Optional[torch.Tensor] = None,
|
| 562 |
+
output_attentions: bool = False,
|
| 563 |
+
output_hidden_states: bool = False,
|
| 564 |
+
return_dict: bool = True,
|
| 565 |
+
) -> Union[tuple, BaseModelOutput]:
|
| 566 |
+
all_hidden_states = () if output_hidden_states else None
|
| 567 |
+
all_self_attentions = () if output_attentions else None
|
| 568 |
+
|
| 569 |
+
for i, layer_module in enumerate(self.layer):
|
| 570 |
+
if output_hidden_states:
|
| 571 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 572 |
+
|
| 573 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 574 |
+
|
| 575 |
+
if self.gradient_checkpointing and self.training:
|
| 576 |
+
|
| 577 |
+
def create_custom_forward(module):
|
| 578 |
+
def custom_forward(*inputs):
|
| 579 |
+
return module(*inputs, output_attentions)
|
| 580 |
+
|
| 581 |
+
return custom_forward
|
| 582 |
+
|
| 583 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 584 |
+
create_custom_forward(layer_module),
|
| 585 |
+
hidden_states,
|
| 586 |
+
layer_head_mask,
|
| 587 |
+
modulation_cond,
|
| 588 |
+
use_reentrant=False,
|
| 589 |
+
)
|
| 590 |
+
else:
|
| 591 |
+
layer_outputs = layer_module(
|
| 592 |
+
hidden_states, layer_head_mask, modulation_cond, output_attentions
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
hidden_states = layer_outputs[0]
|
| 596 |
+
|
| 597 |
+
if output_attentions:
|
| 598 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 599 |
+
|
| 600 |
+
if output_hidden_states:
|
| 601 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 602 |
+
|
| 603 |
+
if not return_dict:
|
| 604 |
+
return tuple(
|
| 605 |
+
v
|
| 606 |
+
for v in [hidden_states, all_hidden_states, all_self_attentions]
|
| 607 |
+
if v is not None
|
| 608 |
+
)
|
| 609 |
+
return BaseModelOutput(
|
| 610 |
+
last_hidden_state=hidden_states,
|
| 611 |
+
hidden_states=all_hidden_states,
|
| 612 |
+
attentions=all_self_attentions,
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
class Dinov2PreTrainedModel(PreTrainedModel):
|
| 617 |
+
"""
|
| 618 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 619 |
+
models.
|
| 620 |
+
"""
|
| 621 |
+
|
| 622 |
+
config_class = Dinov2Config
|
| 623 |
+
base_model_prefix = "dinov2"
|
| 624 |
+
main_input_name = "pixel_values"
|
| 625 |
+
supports_gradient_checkpointing = True
|
| 626 |
+
|
| 627 |
+
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
| 628 |
+
"""Initialize the weights"""
|
| 629 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
| 630 |
+
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
|
| 631 |
+
# `trunc_normal_cpu` not implemented in `half` issues
|
| 632 |
+
module.weight.data = nn.init.trunc_normal_(
|
| 633 |
+
module.weight.data.to(torch.float32),
|
| 634 |
+
mean=0.0,
|
| 635 |
+
std=self.config.initializer_range,
|
| 636 |
+
).to(module.weight.dtype)
|
| 637 |
+
if module.bias is not None:
|
| 638 |
+
module.bias.data.zero_()
|
| 639 |
+
elif isinstance(module, nn.LayerNorm):
|
| 640 |
+
module.bias.data.zero_()
|
| 641 |
+
module.weight.data.fill_(1.0)
|
| 642 |
+
elif isinstance(module, Dinov2Embeddings):
|
| 643 |
+
module.position_embeddings.data = nn.init.trunc_normal_(
|
| 644 |
+
module.position_embeddings.data.to(torch.float32),
|
| 645 |
+
mean=0.0,
|
| 646 |
+
std=self.config.initializer_range,
|
| 647 |
+
).to(module.position_embeddings.dtype)
|
| 648 |
+
|
| 649 |
+
module.cls_token.data = nn.init.trunc_normal_(
|
| 650 |
+
module.cls_token.data.to(torch.float32),
|
| 651 |
+
mean=0.0,
|
| 652 |
+
std=self.config.initializer_range,
|
| 653 |
+
).to(module.cls_token.dtype)
|
| 654 |
+
|
| 655 |
+
def _set_gradient_checkpointing(
|
| 656 |
+
self, module: Dinov2Encoder, value: bool = False
|
| 657 |
+
) -> None:
|
| 658 |
+
if isinstance(module, Dinov2Encoder):
|
| 659 |
+
module.gradient_checkpointing = value
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
DINOV2_START_DOCSTRING = r"""
|
| 663 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
| 664 |
+
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 665 |
+
behavior.
|
| 666 |
+
|
| 667 |
+
Parameters:
|
| 668 |
+
config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
|
| 669 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 670 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 671 |
+
"""
|
| 672 |
+
|
| 673 |
+
DINOV2_BASE_INPUTS_DOCSTRING = r"""
|
| 674 |
+
Args:
|
| 675 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 676 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
| 677 |
+
[`BitImageProcessor.preprocess`] for details.
|
| 678 |
+
|
| 679 |
+
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
|
| 680 |
+
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
|
| 681 |
+
pre-training.
|
| 682 |
+
|
| 683 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 684 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 685 |
+
|
| 686 |
+
- 1 indicates the head is **not masked**,
|
| 687 |
+
- 0 indicates the head is **masked**.
|
| 688 |
+
|
| 689 |
+
output_attentions (`bool`, *optional*):
|
| 690 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 691 |
+
tensors for more detail.
|
| 692 |
+
output_hidden_states (`bool`, *optional*):
|
| 693 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 694 |
+
more detail.
|
| 695 |
+
return_dict (`bool`, *optional*):
|
| 696 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 697 |
+
"""
|
| 698 |
+
|
| 699 |
+
DINOV2_INPUTS_DOCSTRING = r"""
|
| 700 |
+
Args:
|
| 701 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 702 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
| 703 |
+
[`BitImageProcessor.preprocess`] for details.
|
| 704 |
+
|
| 705 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 706 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 707 |
+
|
| 708 |
+
- 1 indicates the head is **not masked**,
|
| 709 |
+
- 0 indicates the head is **masked**.
|
| 710 |
+
|
| 711 |
+
output_attentions (`bool`, *optional*):
|
| 712 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 713 |
+
tensors for more detail.
|
| 714 |
+
output_hidden_states (`bool`, *optional*):
|
| 715 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 716 |
+
more detail.
|
| 717 |
+
return_dict (`bool`, *optional*):
|
| 718 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 719 |
+
"""
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
@dataclass
|
| 723 |
+
class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling):
|
| 724 |
+
patch_embeddings: Optional[torch.FloatTensor] = None
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
@add_start_docstrings(
|
| 728 |
+
"The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
|
| 729 |
+
DINOV2_START_DOCSTRING,
|
| 730 |
+
)
|
| 731 |
+
class Dinov2Model(Dinov2PreTrainedModel):
|
| 732 |
+
def __init__(self, config: Dinov2Config):
|
| 733 |
+
super().__init__(config)
|
| 734 |
+
self.config = config
|
| 735 |
+
|
| 736 |
+
self.embeddings = Dinov2Embeddings(config)
|
| 737 |
+
self.encoder = Dinov2Encoder(config)
|
| 738 |
+
|
| 739 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 740 |
+
|
| 741 |
+
# Initialize weights and apply final processing
|
| 742 |
+
self.post_init()
|
| 743 |
+
|
| 744 |
+
def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
|
| 745 |
+
return self.embeddings.patch_embeddings
|
| 746 |
+
|
| 747 |
+
def expand_input_channels(self, extra_input_channels: int) -> None:
|
| 748 |
+
if extra_input_channels == 0:
|
| 749 |
+
return
|
| 750 |
+
conv_old = self.embeddings.patch_embeddings.projection
|
| 751 |
+
conv_new = nn.Conv2d(
|
| 752 |
+
self.config.num_channels + extra_input_channels,
|
| 753 |
+
self.config.hidden_size,
|
| 754 |
+
kernel_size=self.config.patch_size,
|
| 755 |
+
stride=self.config.patch_size,
|
| 756 |
+
).to(self.device)
|
| 757 |
+
with torch.no_grad():
|
| 758 |
+
conv_new.weight[:, :3] = conv_old.weight
|
| 759 |
+
conv_new.bias = conv_old.bias
|
| 760 |
+
self.embeddings.patch_embeddings.projection = conv_new
|
| 761 |
+
del conv_old
|
| 762 |
+
|
| 763 |
+
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
| 764 |
+
"""
|
| 765 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 766 |
+
class PreTrainedModel
|
| 767 |
+
"""
|
| 768 |
+
for layer, heads in heads_to_prune.items():
|
| 769 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 770 |
+
|
| 771 |
+
@add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
|
| 772 |
+
@add_code_sample_docstrings(
|
| 773 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 774 |
+
output_type=BaseModelOutputWithPooling,
|
| 775 |
+
config_class=_CONFIG_FOR_DOC,
|
| 776 |
+
modality="vision",
|
| 777 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
| 778 |
+
)
|
| 779 |
+
def forward(
|
| 780 |
+
self,
|
| 781 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 782 |
+
bool_masked_pos: Optional[torch.Tensor] = None,
|
| 783 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 784 |
+
modulation_cond: Optional[torch.Tensor] = None,
|
| 785 |
+
output_attentions: Optional[bool] = None,
|
| 786 |
+
output_hidden_states: Optional[bool] = None,
|
| 787 |
+
return_dict: Optional[bool] = None,
|
| 788 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
| 789 |
+
output_attentions = (
|
| 790 |
+
output_attentions
|
| 791 |
+
if output_attentions is not None
|
| 792 |
+
else self.config.output_attentions
|
| 793 |
+
)
|
| 794 |
+
output_hidden_states = (
|
| 795 |
+
output_hidden_states
|
| 796 |
+
if output_hidden_states is not None
|
| 797 |
+
else self.config.output_hidden_states
|
| 798 |
+
)
|
| 799 |
+
return_dict = (
|
| 800 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 801 |
+
)
|
| 802 |
+
|
| 803 |
+
if pixel_values is None:
|
| 804 |
+
raise ValueError("You have to specify pixel_values")
|
| 805 |
+
|
| 806 |
+
# Prepare head mask if needed
|
| 807 |
+
# 1.0 in head_mask indicate we keep the head
|
| 808 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 809 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 810 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 811 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 812 |
+
|
| 813 |
+
embedding_output = self.embeddings(
|
| 814 |
+
pixel_values, bool_masked_pos=bool_masked_pos
|
| 815 |
+
)
|
| 816 |
+
|
| 817 |
+
encoder_outputs = self.encoder(
|
| 818 |
+
embedding_output,
|
| 819 |
+
head_mask=head_mask,
|
| 820 |
+
modulation_cond=modulation_cond,
|
| 821 |
+
output_attentions=output_attentions,
|
| 822 |
+
output_hidden_states=output_hidden_states,
|
| 823 |
+
return_dict=return_dict,
|
| 824 |
+
)
|
| 825 |
+
sequence_output = encoder_outputs[0]
|
| 826 |
+
sequence_output = self.layernorm(sequence_output)
|
| 827 |
+
pooled_output = sequence_output[:, 0, :]
|
| 828 |
+
|
| 829 |
+
if not return_dict:
|
| 830 |
+
head_outputs = (sequence_output, pooled_output)
|
| 831 |
+
return head_outputs + encoder_outputs[1:]
|
| 832 |
+
|
| 833 |
+
return CustomBaseModelOutputWithPooling(
|
| 834 |
+
last_hidden_state=sequence_output,
|
| 835 |
+
pooler_output=pooled_output,
|
| 836 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 837 |
+
attentions=encoder_outputs.attentions,
|
| 838 |
+
patch_embeddings=embedding_output,
|
| 839 |
+
)
|
| 840 |
+
|
| 841 |
+
def set_gradient_checkpointing(self, value: bool = False) -> None:
|
| 842 |
+
self._set_gradient_checkpointing(self.encoder, value)
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
@add_start_docstrings(
|
| 846 |
+
"""
|
| 847 |
+
Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
|
| 848 |
+
of the [CLS] token) e.g. for ImageNet.
|
| 849 |
+
""",
|
| 850 |
+
DINOV2_START_DOCSTRING,
|
| 851 |
+
)
|
| 852 |
+
class Dinov2ForImageClassification(Dinov2PreTrainedModel):
|
| 853 |
+
def __init__(self, config: Dinov2Config) -> None:
|
| 854 |
+
super().__init__(config)
|
| 855 |
+
|
| 856 |
+
self.num_labels = config.num_labels
|
| 857 |
+
self.dinov2 = Dinov2Model(config)
|
| 858 |
+
|
| 859 |
+
# Classifier head
|
| 860 |
+
self.classifier = (
|
| 861 |
+
nn.Linear(config.hidden_size * 2, config.num_labels)
|
| 862 |
+
if config.num_labels > 0
|
| 863 |
+
else nn.Identity()
|
| 864 |
+
)
|
| 865 |
+
|
| 866 |
+
# Initialize weights and apply final processing
|
| 867 |
+
self.post_init()
|
| 868 |
+
|
| 869 |
+
@add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
|
| 870 |
+
@add_code_sample_docstrings(
|
| 871 |
+
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
| 872 |
+
output_type=ImageClassifierOutput,
|
| 873 |
+
config_class=_CONFIG_FOR_DOC,
|
| 874 |
+
)
|
| 875 |
+
def forward(
|
| 876 |
+
self,
|
| 877 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 878 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 879 |
+
labels: Optional[torch.Tensor] = None,
|
| 880 |
+
output_attentions: Optional[bool] = None,
|
| 881 |
+
output_hidden_states: Optional[bool] = None,
|
| 882 |
+
return_dict: Optional[bool] = None,
|
| 883 |
+
) -> Union[tuple, ImageClassifierOutput]:
|
| 884 |
+
r"""
|
| 885 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 886 |
+
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
| 887 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 888 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 889 |
+
"""
|
| 890 |
+
return_dict = (
|
| 891 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 892 |
+
)
|
| 893 |
+
|
| 894 |
+
outputs = self.dinov2(
|
| 895 |
+
pixel_values,
|
| 896 |
+
head_mask=head_mask,
|
| 897 |
+
output_attentions=output_attentions,
|
| 898 |
+
output_hidden_states=output_hidden_states,
|
| 899 |
+
return_dict=return_dict,
|
| 900 |
+
)
|
| 901 |
+
|
| 902 |
+
sequence_output = outputs[0] # batch_size, sequence_length, hidden_size
|
| 903 |
+
|
| 904 |
+
cls_token = sequence_output[:, 0]
|
| 905 |
+
patch_tokens = sequence_output[:, 1:]
|
| 906 |
+
|
| 907 |
+
linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
|
| 908 |
+
|
| 909 |
+
logits = self.classifier(linear_input)
|
| 910 |
+
|
| 911 |
+
loss = None
|
| 912 |
+
if labels is not None:
|
| 913 |
+
# move labels to correct device to enable model parallelism
|
| 914 |
+
labels = labels.to(logits.device)
|
| 915 |
+
if self.config.problem_type is None:
|
| 916 |
+
if self.num_labels == 1:
|
| 917 |
+
self.config.problem_type = "regression"
|
| 918 |
+
elif self.num_labels > 1 and (
|
| 919 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
| 920 |
+
):
|
| 921 |
+
self.config.problem_type = "single_label_classification"
|
| 922 |
+
else:
|
| 923 |
+
self.config.problem_type = "multi_label_classification"
|
| 924 |
+
|
| 925 |
+
if self.config.problem_type == "regression":
|
| 926 |
+
loss_fct = MSELoss()
|
| 927 |
+
if self.num_labels == 1:
|
| 928 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
| 929 |
+
else:
|
| 930 |
+
loss = loss_fct(logits, labels)
|
| 931 |
+
elif self.config.problem_type == "single_label_classification":
|
| 932 |
+
loss_fct = CrossEntropyLoss()
|
| 933 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 934 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 935 |
+
loss_fct = BCEWithLogitsLoss()
|
| 936 |
+
loss = loss_fct(logits, labels)
|
| 937 |
+
|
| 938 |
+
if not return_dict:
|
| 939 |
+
output = (logits,) + outputs[2:]
|
| 940 |
+
return ((loss,) + output) if loss is not None else output
|
| 941 |
+
|
| 942 |
+
return ImageClassifierOutput(
|
| 943 |
+
loss=loss,
|
| 944 |
+
logits=logits,
|
| 945 |
+
hidden_states=outputs.hidden_states,
|
| 946 |
+
attentions=outputs.attentions,
|
| 947 |
+
)
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
@add_start_docstrings(
|
| 951 |
+
"""
|
| 952 |
+
Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
|
| 953 |
+
""",
|
| 954 |
+
DINOV2_START_DOCSTRING,
|
| 955 |
+
)
|
| 956 |
+
class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
|
| 957 |
+
def __init__(self, config):
|
| 958 |
+
super().__init__(config)
|
| 959 |
+
super()._init_backbone(config)
|
| 960 |
+
|
| 961 |
+
self.num_features = [
|
| 962 |
+
config.hidden_size for _ in range(config.num_hidden_layers + 1)
|
| 963 |
+
]
|
| 964 |
+
self.embeddings = Dinov2Embeddings(config)
|
| 965 |
+
self.encoder = Dinov2Encoder(config)
|
| 966 |
+
|
| 967 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 968 |
+
|
| 969 |
+
# Initialize weights and apply final processing
|
| 970 |
+
self.post_init()
|
| 971 |
+
|
| 972 |
+
def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
|
| 973 |
+
return self.embeddings.patch_embeddings
|
| 974 |
+
|
| 975 |
+
@add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
|
| 976 |
+
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
|
| 977 |
+
def forward(
|
| 978 |
+
self,
|
| 979 |
+
pixel_values: torch.Tensor,
|
| 980 |
+
output_hidden_states: Optional[bool] = None,
|
| 981 |
+
output_attentions: Optional[bool] = None,
|
| 982 |
+
return_dict: Optional[bool] = None,
|
| 983 |
+
) -> BackboneOutput:
|
| 984 |
+
"""
|
| 985 |
+
Returns:
|
| 986 |
+
|
| 987 |
+
Examples:
|
| 988 |
+
|
| 989 |
+
```python
|
| 990 |
+
>>> from transformers import AutoImageProcessor, AutoBackbone
|
| 991 |
+
>>> import torch
|
| 992 |
+
>>> from PIL import Image
|
| 993 |
+
>>> import requests
|
| 994 |
+
|
| 995 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 996 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 997 |
+
|
| 998 |
+
>>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
|
| 999 |
+
>>> model = AutoBackbone.from_pretrained(
|
| 1000 |
+
... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
|
| 1001 |
+
... )
|
| 1002 |
+
|
| 1003 |
+
>>> inputs = processor(image, return_tensors="pt")
|
| 1004 |
+
|
| 1005 |
+
>>> outputs = model(**inputs)
|
| 1006 |
+
>>> feature_maps = outputs.feature_maps
|
| 1007 |
+
>>> list(feature_maps[-1].shape)
|
| 1008 |
+
[1, 768, 16, 16]
|
| 1009 |
+
```"""
|
| 1010 |
+
return_dict = (
|
| 1011 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 1012 |
+
)
|
| 1013 |
+
output_hidden_states = (
|
| 1014 |
+
output_hidden_states
|
| 1015 |
+
if output_hidden_states is not None
|
| 1016 |
+
else self.config.output_hidden_states
|
| 1017 |
+
)
|
| 1018 |
+
output_attentions = (
|
| 1019 |
+
output_attentions
|
| 1020 |
+
if output_attentions is not None
|
| 1021 |
+
else self.config.output_attentions
|
| 1022 |
+
)
|
| 1023 |
+
|
| 1024 |
+
embedding_output = self.embeddings(pixel_values)
|
| 1025 |
+
|
| 1026 |
+
outputs = self.encoder(
|
| 1027 |
+
embedding_output,
|
| 1028 |
+
output_hidden_states=True,
|
| 1029 |
+
output_attentions=output_attentions,
|
| 1030 |
+
return_dict=return_dict,
|
| 1031 |
+
)
|
| 1032 |
+
|
| 1033 |
+
hidden_states = outputs.hidden_states if return_dict else outputs[1]
|
| 1034 |
+
|
| 1035 |
+
feature_maps = ()
|
| 1036 |
+
for stage, hidden_state in zip(self.stage_names, hidden_states):
|
| 1037 |
+
if stage in self.out_features:
|
| 1038 |
+
if self.config.apply_layernorm:
|
| 1039 |
+
hidden_state = self.layernorm(hidden_state)
|
| 1040 |
+
if self.config.reshape_hidden_states:
|
| 1041 |
+
batch_size, _, height, width = pixel_values.shape
|
| 1042 |
+
patch_size = self.config.patch_size
|
| 1043 |
+
hidden_state = hidden_state[:, 1:, :].reshape(
|
| 1044 |
+
batch_size, width // patch_size, height // patch_size, -1
|
| 1045 |
+
)
|
| 1046 |
+
hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
|
| 1047 |
+
feature_maps += (hidden_state,)
|
| 1048 |
+
|
| 1049 |
+
if not return_dict:
|
| 1050 |
+
if output_hidden_states:
|
| 1051 |
+
output = (feature_maps,) + outputs[1:]
|
| 1052 |
+
else:
|
| 1053 |
+
output = (feature_maps,) + outputs[2:]
|
| 1054 |
+
return output
|
| 1055 |
+
|
| 1056 |
+
return BackboneOutput(
|
| 1057 |
+
feature_maps=feature_maps,
|
| 1058 |
+
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
| 1059 |
+
attentions=outputs.attentions if output_attentions else None,
|
| 1060 |
+
)
|
| 1061 |
+
|
| 1062 |
+
|
| 1063 |
+
class CustomPatchEmbeddings(nn.Module):
|
| 1064 |
+
"""
|
| 1065 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
| 1066 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
| 1067 |
+
Transformer.
|
| 1068 |
+
"""
|
| 1069 |
+
|
| 1070 |
+
def __init__(
|
| 1071 |
+
self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
|
| 1072 |
+
):
|
| 1073 |
+
super().__init__()
|
| 1074 |
+
|
| 1075 |
+
image_size = (
|
| 1076 |
+
image_size
|
| 1077 |
+
if isinstance(image_size, collections.abc.Iterable)
|
| 1078 |
+
else (image_size, image_size)
|
| 1079 |
+
)
|
| 1080 |
+
patch_size = (
|
| 1081 |
+
patch_size
|
| 1082 |
+
if isinstance(patch_size, collections.abc.Iterable)
|
| 1083 |
+
else (patch_size, patch_size)
|
| 1084 |
+
)
|
| 1085 |
+
num_patches = (image_size[1] // patch_size[1]) * (
|
| 1086 |
+
image_size[0] // patch_size[0]
|
| 1087 |
+
)
|
| 1088 |
+
self.image_size = image_size
|
| 1089 |
+
self.patch_size = patch_size
|
| 1090 |
+
self.num_channels = num_channels
|
| 1091 |
+
self.num_patches = num_patches
|
| 1092 |
+
|
| 1093 |
+
self.projection = nn.Conv2d(
|
| 1094 |
+
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
|
| 1095 |
+
)
|
| 1096 |
+
|
| 1097 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
| 1098 |
+
num_channels = pixel_values.shape[1]
|
| 1099 |
+
if num_channels != self.num_channels:
|
| 1100 |
+
raise ValueError(
|
| 1101 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
| 1102 |
+
f" Expected {self.num_channels} but got {num_channels}."
|
| 1103 |
+
)
|
| 1104 |
+
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
| 1105 |
+
return embeddings
|
| 1106 |
+
|
| 1107 |
+
|
| 1108 |
+
class CustomEmbeddings(nn.Module):
|
| 1109 |
+
"""
|
| 1110 |
+
Construct the CLS token, mask token, position and patch embeddings.
|
| 1111 |
+
"""
|
| 1112 |
+
|
| 1113 |
+
def __init__(
|
| 1114 |
+
self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
|
| 1115 |
+
) -> None:
|
| 1116 |
+
super().__init__()
|
| 1117 |
+
|
| 1118 |
+
self.image_size = image_size
|
| 1119 |
+
self.patch_size = patch_size
|
| 1120 |
+
self.num_channels = num_channels
|
| 1121 |
+
self.hidden_size = hidden_size
|
| 1122 |
+
|
| 1123 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
|
| 1124 |
+
|
| 1125 |
+
self.patch_embeddings = CustomPatchEmbeddings(
|
| 1126 |
+
image_size, patch_size, num_channels, hidden_size
|
| 1127 |
+
)
|
| 1128 |
+
num_patches = self.patch_embeddings.num_patches
|
| 1129 |
+
self.position_embeddings = nn.Parameter(
|
| 1130 |
+
torch.randn(1, num_patches + 1, self.hidden_size)
|
| 1131 |
+
)
|
| 1132 |
+
|
| 1133 |
+
def interpolate_pos_encoding(
|
| 1134 |
+
self, embeddings: torch.Tensor, height: int, width: int
|
| 1135 |
+
) -> torch.Tensor:
|
| 1136 |
+
"""
|
| 1137 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
| 1138 |
+
resolution images.
|
| 1139 |
+
|
| 1140 |
+
Source:
|
| 1141 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
| 1142 |
+
"""
|
| 1143 |
+
|
| 1144 |
+
num_patches = embeddings.shape[1] - 1
|
| 1145 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
| 1146 |
+
if num_patches == num_positions and height == width:
|
| 1147 |
+
return self.position_embeddings
|
| 1148 |
+
class_pos_embed = self.position_embeddings[:, 0]
|
| 1149 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
| 1150 |
+
dim = embeddings.shape[-1]
|
| 1151 |
+
height = height // self.patch_size
|
| 1152 |
+
width = width // self.patch_size
|
| 1153 |
+
# we add a small number to avoid floating point error in the interpolation
|
| 1154 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
| 1155 |
+
height, width = height + 0.1, width + 0.1
|
| 1156 |
+
patch_pos_embed = patch_pos_embed.reshape(
|
| 1157 |
+
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
|
| 1158 |
+
)
|
| 1159 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
| 1160 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 1161 |
+
patch_pos_embed,
|
| 1162 |
+
scale_factor=(
|
| 1163 |
+
height / math.sqrt(num_positions),
|
| 1164 |
+
width / math.sqrt(num_positions),
|
| 1165 |
+
),
|
| 1166 |
+
mode="bicubic",
|
| 1167 |
+
align_corners=False,
|
| 1168 |
+
)
|
| 1169 |
+
if (
|
| 1170 |
+
int(height) != patch_pos_embed.shape[-2]
|
| 1171 |
+
or int(width) != patch_pos_embed.shape[-1]
|
| 1172 |
+
):
|
| 1173 |
+
raise ValueError(
|
| 1174 |
+
"Width or height does not match with the interpolated position embeddings"
|
| 1175 |
+
)
|
| 1176 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 1177 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
| 1178 |
+
|
| 1179 |
+
def forward(
|
| 1180 |
+
self,
|
| 1181 |
+
pixel_values: torch.Tensor,
|
| 1182 |
+
) -> torch.Tensor:
|
| 1183 |
+
batch_size, _, height, width = pixel_values.shape
|
| 1184 |
+
patch_embeddings = self.patch_embeddings(pixel_values)
|
| 1185 |
+
embeddings = patch_embeddings
|
| 1186 |
+
|
| 1187 |
+
# add the [CLS] token to the embedded patch tokens
|
| 1188 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
| 1189 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
| 1190 |
+
|
| 1191 |
+
# add positional encoding to each token
|
| 1192 |
+
embeddings = embeddings + self.interpolate_pos_encoding(
|
| 1193 |
+
embeddings, height, width
|
| 1194 |
+
)
|
| 1195 |
+
|
| 1196 |
+
return embeddings
|
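As a side note on the embeddings above, here is a minimal standalone sketch (PyTorch only; the sizes are made up, and it uses size= instead of the scale_factor trick used in interpolate_pos_encoding) of the same reshape, bicubic-interpolate, flatten pipeline applied to the patch position embeddings:

import torch
import torch.nn.functional as F

dim = 64                                    # embedding width (made up)
side = 16                                   # pretrained patch grid is side x side
pos = torch.randn(1, side * side + 1, dim)  # [CLS] + patch position embeddings

new_h, new_w = 24, 18                       # patch grid of a larger input image
cls_pos, patch_pos = pos[:, :1], pos[:, 1:]

grid = patch_pos.reshape(1, side, side, dim).permute(0, 3, 1, 2)   # [1, dim, 16, 16]
grid = F.interpolate(grid, size=(new_h, new_w), mode="bicubic", align_corners=False)
patch_pos = grid.permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)

pos_resized = torch.cat([cls_pos, patch_pos], dim=1)
print(pos_resized.shape)                    # torch.Size([1, 433, 64])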
sf3d/models/tokenizers/image.py
ADDED
|
@@ -0,0 +1,101 @@
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from jaxtyping import Float
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
|
| 10 |
+
from sf3d.models.tokenizers.dinov2 import Dinov2Model
|
| 11 |
+
from sf3d.models.transformers.attention import Modulation
|
| 12 |
+
from sf3d.models.utils import BaseModule
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DINOV2SingleImageTokenizer(BaseModule):
|
| 16 |
+
@dataclass
|
| 17 |
+
class Config(BaseModule.Config):
|
| 18 |
+
pretrained_model_name_or_path: str = "facebook/dinov2-large"
|
| 19 |
+
width: int = 512
|
| 20 |
+
height: int = 512
|
| 21 |
+
modulation_cond_dim: int = 768
|
| 22 |
+
|
| 23 |
+
cfg: Config
|
| 24 |
+
|
| 25 |
+
def configure(self) -> None:
|
| 26 |
+
self.model = Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path)
|
| 27 |
+
|
| 28 |
+
for p in self.model.parameters():
|
| 29 |
+
p.requires_grad_(False)
|
| 30 |
+
self.model.eval()
|
| 31 |
+
|
| 32 |
+
self.model.set_gradient_checkpointing(False)
|
| 33 |
+
|
| 34 |
+
# add modulation
|
| 35 |
+
modulations = []
|
| 36 |
+
for layer in self.model.encoder.layer:
|
| 37 |
+
norm1_modulation = Modulation(
|
| 38 |
+
self.model.config.hidden_size,
|
| 39 |
+
self.cfg.modulation_cond_dim,
|
| 40 |
+
zero_init=True,
|
| 41 |
+
single_layer=True,
|
| 42 |
+
)
|
| 43 |
+
norm2_modulation = Modulation(
|
| 44 |
+
self.model.config.hidden_size,
|
| 45 |
+
self.cfg.modulation_cond_dim,
|
| 46 |
+
zero_init=True,
|
| 47 |
+
single_layer=True,
|
| 48 |
+
)
|
| 49 |
+
layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
|
| 50 |
+
modulations += [norm1_modulation, norm2_modulation]
|
| 51 |
+
self.modulations = nn.ModuleList(modulations)
|
| 52 |
+
|
| 53 |
+
self.register_buffer(
|
| 54 |
+
"image_mean",
|
| 55 |
+
torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
|
| 56 |
+
persistent=False,
|
| 57 |
+
)
|
| 58 |
+
self.register_buffer(
|
| 59 |
+
"image_std",
|
| 60 |
+
torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
|
| 61 |
+
persistent=False,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
def forward(
|
| 65 |
+
self,
|
| 66 |
+
images: Float[Tensor, "B *N C H W"],
|
| 67 |
+
modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
|
| 68 |
+
**kwargs,
|
| 69 |
+
) -> Float[Tensor, "B *N Ct Nt"]:
|
| 70 |
+
model = self.model
|
| 71 |
+
|
| 72 |
+
packed = False
|
| 73 |
+
if images.ndim == 4:
|
| 74 |
+
packed = True
|
| 75 |
+
images = images.unsqueeze(1)
|
| 76 |
+
if modulation_cond is not None:
|
| 77 |
+
assert modulation_cond.ndim == 2
|
| 78 |
+
modulation_cond = modulation_cond.unsqueeze(1)
|
| 79 |
+
|
| 80 |
+
batch_size, n_input_views = images.shape[:2]
|
| 81 |
+
images = (images - self.image_mean) / self.image_std
|
| 82 |
+
out = model(
|
| 83 |
+
rearrange(images, "B N C H W -> (B N) C H W"),
|
| 84 |
+
modulation_cond=(
|
| 85 |
+
rearrange(modulation_cond, "B N Cc -> (B N) Cc")
|
| 86 |
+
if modulation_cond is not None
|
| 87 |
+
else None
|
| 88 |
+
),
|
| 89 |
+
)
|
| 90 |
+
local_features = out.last_hidden_state
|
| 91 |
+
local_features = local_features.permute(0, 2, 1)
|
| 92 |
+
local_features = rearrange(
|
| 93 |
+
local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
|
| 94 |
+
)
|
| 95 |
+
if packed:
|
| 96 |
+
local_features = local_features.squeeze(1)
|
| 97 |
+
|
| 98 |
+
return local_features
|
| 99 |
+
|
| 100 |
+
def detokenize(self, *args, **kwargs):
|
| 101 |
+
raise NotImplementedError
|
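A hedged usage sketch for the tokenizer above (not part of the diff; it assumes the default config, an importable sf3d package, and that the pretrained DINOv2 weights can be downloaded). Inputs are expected in [0, 1]; the ImageNet normalization happens inside forward:

import torch
from sf3d.models.tokenizers.image import DINOV2SingleImageTokenizer

tokenizer = DINOV2SingleImageTokenizer({"modulation_cond_dim": 768})
images = torch.rand(2, 3, 512, 512)  # packed input: [B, C, H, W] in [0, 1]
cond = torch.zeros(2, 768)           # per-image modulation condition
tokens = tokenizer(images, cond)     # [B, Ct, Nt]; Ct = DINOv2 hidden size, Nt = 1 + number of patches
print(tokens.shape)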
sf3d/models/tokenizers/triplane.py
ADDED
|
@@ -0,0 +1,49 @@
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from einops import rearrange, repeat
|
| 7 |
+
from jaxtyping import Float
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
|
| 10 |
+
from sf3d.models.utils import BaseModule
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TriplaneLearnablePositionalEmbedding(BaseModule):
|
| 14 |
+
@dataclass
|
| 15 |
+
class Config(BaseModule.Config):
|
| 16 |
+
plane_size: int = 96
|
| 17 |
+
num_channels: int = 1024
|
| 18 |
+
|
| 19 |
+
cfg: Config
|
| 20 |
+
|
| 21 |
+
def configure(self) -> None:
|
| 22 |
+
self.embeddings = nn.Parameter(
|
| 23 |
+
torch.randn(
|
| 24 |
+
(3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
|
| 25 |
+
dtype=torch.float32,
|
| 26 |
+
)
|
| 27 |
+
* 1
|
| 28 |
+
/ math.sqrt(self.cfg.num_channels)
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
|
| 32 |
+
return rearrange(
|
| 33 |
+
repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
|
| 34 |
+
"B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
def detokenize(
|
| 38 |
+
self, tokens: Float[Tensor, "B Ct Nt"]
|
| 39 |
+
) -> Float[Tensor, "B 3 Ct Hp Wp"]:
|
| 40 |
+
batch_size, Ct, Nt = tokens.shape
|
| 41 |
+
assert Nt == self.cfg.plane_size**2 * 3
|
| 42 |
+
assert Ct == self.cfg.num_channels
|
| 43 |
+
return rearrange(
|
| 44 |
+
tokens,
|
| 45 |
+
"B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
|
| 46 |
+
Np=3,
|
| 47 |
+
Hp=self.cfg.plane_size,
|
| 48 |
+
Wp=self.cfg.plane_size,
|
| 49 |
+
)
|
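A small shape-check sketch (assumed default config; not part of the diff): the learnable triplane embedding emits flat tokens and detokenize folds them back into the three planes:

import torch
from sf3d.models.tokenizers.triplane import TriplaneLearnablePositionalEmbedding

triplane = TriplaneLearnablePositionalEmbedding({"plane_size": 96, "num_channels": 1024})
tokens = triplane(batch_size=2)       # [2, 1024, 3 * 96 * 96]
planes = triplane.detokenize(tokens)  # [2, 3, 1024, 96, 96]
print(tokens.shape, planes.shape)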
sf3d/models/transformers/attention.py
ADDED
|
@@ -0,0 +1,31 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Modulation(nn.Module):
|
| 6 |
+
def __init__(
|
| 7 |
+
self,
|
| 8 |
+
embedding_dim: int,
|
| 9 |
+
condition_dim: int,
|
| 10 |
+
zero_init: bool = False,
|
| 11 |
+
single_layer: bool = False,
|
| 12 |
+
):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.silu = nn.SiLU()
|
| 15 |
+
if single_layer:
|
| 16 |
+
self.linear1 = nn.Identity()
|
| 17 |
+
else:
|
| 18 |
+
self.linear1 = nn.Linear(condition_dim, condition_dim)
|
| 19 |
+
|
| 20 |
+
self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)
|
| 21 |
+
|
| 22 |
+
# Only zero init the last linear layer
|
| 23 |
+
if zero_init:
|
| 24 |
+
nn.init.zeros_(self.linear2.weight)
|
| 25 |
+
nn.init.zeros_(self.linear2.bias)
|
| 26 |
+
|
| 27 |
+
def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
|
| 28 |
+
emb = self.linear2(self.silu(self.linear1(condition)))
|
| 29 |
+
scale, shift = torch.chunk(emb, 2, dim=1)
|
| 30 |
+
x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
| 31 |
+
return x
|
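A toy sketch of the modulation above (shapes are made up): the condition vector is projected to a per-channel (scale, shift) pair that is broadcast over all tokens, i.e. x * (1 + scale) + shift:

import torch
from sf3d.models.transformers.attention import Modulation

mod = Modulation(embedding_dim=32, condition_dim=16)
x = torch.randn(4, 10, 32)  # [B, tokens, embedding_dim]
cond = torch.randn(4, 16)   # [B, condition_dim]
out = mod(x, cond)          # same shape as x
print(out.shape)            # torch.Size([4, 10, 32])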
sf3d/models/transformers/backbone.py
ADDED
|
@@ -0,0 +1,515 @@
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from sf3d.models.utils import BaseModule
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class GEGLU(nn.Module):
|
| 12 |
+
r"""
|
| 13 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
| 14 |
+
|
| 15 |
+
Parameters:
|
| 16 |
+
dim_in (`int`): The number of channels in the input.
|
| 17 |
+
dim_out (`int`): The number of channels in the output.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self, dim_in: int, dim_out: int):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
| 23 |
+
|
| 24 |
+
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
| 25 |
+
if gate.device.type != "mps":
|
| 26 |
+
return F.gelu(gate)
|
| 27 |
+
# mps: gelu is not implemented for float16
|
| 28 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
| 29 |
+
|
| 30 |
+
def forward(self, hidden_states, scale: float = 1.0):
|
| 31 |
+
args = ()
|
| 32 |
+
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
|
| 33 |
+
return hidden_states * self.gelu(gate)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class CrossAttention(nn.Module):
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
dim,
|
| 40 |
+
kv_dim=None,
|
| 41 |
+
num_heads=16,
|
| 42 |
+
qkv_bias=False,
|
| 43 |
+
attn_drop=0.0,
|
| 44 |
+
proj_drop=0.0,
|
| 45 |
+
):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.num_heads = num_heads
|
| 48 |
+
head_dim = dim // num_heads
|
| 49 |
+
self.scale = head_dim**-0.5
|
| 50 |
+
kv_dim = dim if not kv_dim else kv_dim
|
| 51 |
+
self.wq = nn.Linear(dim, dim, bias=qkv_bias)
|
| 52 |
+
self.wk = nn.Linear(kv_dim, dim, bias=qkv_bias)
|
| 53 |
+
self.wv = nn.Linear(kv_dim, dim, bias=qkv_bias)
|
| 54 |
+
self.attn_drop = attn_drop
|
| 55 |
+
self.proj = nn.Linear(dim, dim)
|
| 56 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 57 |
+
|
| 58 |
+
def forward(self, x_q, x_kv):
|
| 59 |
+
B, N_q, C = x_q.shape
|
| 60 |
+
B, N_kv, _ = x_kv.shape
|
| 61 |
+
# [B, N_q, C] -> [B, N_q, H, C/H]
|
| 62 |
+
q = self.wq(x_q).reshape(B, N_q, self.num_heads, C // self.num_heads)
|
| 63 |
+
# [B, N_kv, C] -> [B, N_kv, H, C/H]
|
| 64 |
+
k = self.wk(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
|
| 65 |
+
v = self.wv(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
|
| 66 |
+
|
| 67 |
+
# attention
|
| 68 |
+
x = torch.nn.functional.scaled_dot_product_attention(
|
| 69 |
+
q.permute(0, 2, 1, 3),
|
| 70 |
+
k.permute(0, 2, 1, 3),
|
| 71 |
+
v.permute(0, 2, 1, 3),
|
| 72 |
+
attn_mask=None,
|
| 73 |
+
dropout_p=self.attn_drop,
|
| 74 |
+
scale=self.scale,
|
| 75 |
+
).permute(0, 2, 1, 3)
|
| 76 |
+
|
| 77 |
+
# [B, N_q, H, C/H] -> [B, N_q, C]
|
| 78 |
+
x = x.reshape(B, N_q, C)
|
| 79 |
+
x = self.proj(x)
|
| 80 |
+
x = self.proj_drop(x)
|
| 81 |
+
return x
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class FeedForward(nn.Module):
|
| 85 |
+
def __init__(
|
| 86 |
+
self,
|
| 87 |
+
dim: int,
|
| 88 |
+
dim_out: Optional[int] = None,
|
| 89 |
+
mult: int = 4,
|
| 90 |
+
dropout: float = 0.0,
|
| 91 |
+
):
|
| 92 |
+
super().__init__()
|
| 93 |
+
inner_dim = int(dim * mult)
|
| 94 |
+
dim_out = dim_out if dim_out is not None else dim
|
| 95 |
+
act_fn = GEGLU(dim, inner_dim)
|
| 96 |
+
self.net = nn.ModuleList([])
|
| 97 |
+
self.net.append(act_fn)
|
| 98 |
+
self.net.append(nn.Dropout(dropout))
|
| 99 |
+
self.net.append(nn.Linear(inner_dim, dim_out))
|
| 100 |
+
|
| 101 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 102 |
+
for module in self.net:
|
| 103 |
+
x = module(x)
|
| 104 |
+
return x
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class BasicBlock(nn.Module):
|
| 108 |
+
def __init__(
|
| 109 |
+
self,
|
| 110 |
+
dim: int,
|
| 111 |
+
kv_dim: Optional[int] = None,
|
| 112 |
+
num_heads: int = 16,
|
| 113 |
+
qkv_bias: bool = False,
|
| 114 |
+
attn_drop: float = 0.0,
|
| 115 |
+
proj_drop: float = 0.0,
|
| 116 |
+
ff_drop: float = 0.0,
|
| 117 |
+
):
|
| 118 |
+
super().__init__()
|
| 119 |
+
self.norm1 = nn.LayerNorm(dim)
|
| 120 |
+
self.attn1 = CrossAttention(
|
| 121 |
+
dim,
|
| 122 |
+
kv_dim=dim,
|
| 123 |
+
num_heads=num_heads,
|
| 124 |
+
qkv_bias=qkv_bias,
|
| 125 |
+
attn_drop=attn_drop,
|
| 126 |
+
proj_drop=proj_drop,
|
| 127 |
+
)
|
| 128 |
+
self.norm2 = nn.LayerNorm(dim)
|
| 129 |
+
self.attn2 = CrossAttention(
|
| 130 |
+
dim,
|
| 131 |
+
kv_dim=kv_dim,
|
| 132 |
+
num_heads=num_heads,
|
| 133 |
+
qkv_bias=qkv_bias,
|
| 134 |
+
attn_drop=attn_drop,
|
| 135 |
+
proj_drop=proj_drop,
|
| 136 |
+
)
|
| 137 |
+
self.norm3 = nn.LayerNorm(dim)
|
| 138 |
+
self.ff = FeedForward(dim, dropout=ff_drop)
|
| 139 |
+
|
| 140 |
+
def forward(self, z, x):
|
| 141 |
+
z_norm = self.norm1(z)
|
| 142 |
+
z = z + self.attn1(z_norm, z_norm)
|
| 143 |
+
# TODO: do we need to have the second attention when x is None?
|
| 144 |
+
z_norm = self.norm2(z)
|
| 145 |
+
z = z + self.attn2(z_norm, x if x is not None else z_norm)
|
| 146 |
+
z_norm = self.norm3(z)
|
| 147 |
+
z = z + self.ff(z_norm)
|
| 148 |
+
return z
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class SingleStreamTransformer(BaseModule):
|
| 152 |
+
@dataclass
|
| 153 |
+
class Config(BaseModule.Config):
|
| 154 |
+
num_attention_heads: int = 16
|
| 155 |
+
attention_head_dim: int = 88
|
| 156 |
+
in_channels: Optional[int] = None
|
| 157 |
+
out_channels: Optional[int] = None
|
| 158 |
+
num_layers: int = 16
|
| 159 |
+
dropout: float = 0.0
|
| 160 |
+
norm_num_groups: int = 32
|
| 161 |
+
cross_attention_dim: Optional[int] = None
|
| 162 |
+
attention_bias: bool = False
|
| 163 |
+
|
| 164 |
+
cfg: Config
|
| 165 |
+
|
| 166 |
+
def configure(self) -> None:
|
| 167 |
+
self.num_attention_heads = self.cfg.num_attention_heads
|
| 168 |
+
self.attention_head_dim = self.cfg.attention_head_dim
|
| 169 |
+
inner_dim = self.num_attention_heads * self.attention_head_dim
|
| 170 |
+
|
| 171 |
+
# Define input layers
|
| 172 |
+
self.norm = torch.nn.GroupNorm(
|
| 173 |
+
num_groups=self.cfg.norm_num_groups,
|
| 174 |
+
num_channels=self.cfg.in_channels,
|
| 175 |
+
eps=1e-6,
|
| 176 |
+
affine=True,
|
| 177 |
+
)
|
| 178 |
+
self.proj_in = nn.Linear(self.cfg.in_channels, inner_dim)
|
| 179 |
+
|
| 180 |
+
# Define transformers blocks
|
| 181 |
+
self.transformer_blocks = nn.ModuleList(
|
| 182 |
+
[
|
| 183 |
+
BasicBlock(
|
| 184 |
+
inner_dim,
|
| 185 |
+
kv_dim=self.cfg.cross_attention_dim,
|
| 186 |
+
num_heads=self.num_attention_heads,
|
| 187 |
+
qkv_bias=self.cfg.attention_bias,
|
| 188 |
+
proj_drop=self.cfg.dropout,
|
| 189 |
+
ff_drop=self.cfg.dropout,
|
| 190 |
+
)
|
| 191 |
+
for d in range(self.cfg.num_layers)
|
| 192 |
+
]
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# 4. Define output layers
|
| 196 |
+
self.proj_out = nn.Linear(inner_dim, self.cfg.in_channels)
|
| 197 |
+
|
| 198 |
+
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
|
| 199 |
+
residual = hidden_states
|
| 200 |
+
hidden_states = self.norm(hidden_states)
|
| 201 |
+
hidden_states = hidden_states.permute(0, 2, 1)
|
| 202 |
+
hidden_states = self.proj_in(hidden_states)
|
| 203 |
+
for block in self.transformer_blocks:
|
| 204 |
+
hidden_states = block(hidden_states, encoder_hidden_states)
|
| 205 |
+
hidden_states = self.proj_out(hidden_states).permute(0, 2, 1).contiguous()
|
| 206 |
+
# TODO: do we really need to add the residual?
|
| 207 |
+
hidden_states = hidden_states + residual
|
| 208 |
+
return hidden_states
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class FuseBlock(nn.Module):
|
| 212 |
+
"""
|
| 213 |
+
Fuse X into Z with cross attention
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
def __init__(
|
| 217 |
+
self,
|
| 218 |
+
dim_z: int,
|
| 219 |
+
dim_x: int,
|
| 220 |
+
num_heads: int = 16,
|
| 221 |
+
qkv_bias: bool = False,
|
| 222 |
+
attn_drop: float = 0.0,
|
| 223 |
+
proj_drop: float = 0.0,
|
| 224 |
+
ff_drop: float = 0.0,
|
| 225 |
+
norm_x_input: bool = True,
|
| 226 |
+
):
|
| 227 |
+
super().__init__()
|
| 228 |
+
self.norm_x_input = norm_x_input
|
| 229 |
+
if self.norm_x_input:
|
| 230 |
+
self.norm_x = nn.LayerNorm(dim_x)
|
| 231 |
+
self.attn = CrossAttention(
|
| 232 |
+
dim_z,
|
| 233 |
+
kv_dim=dim_x,
|
| 234 |
+
num_heads=num_heads,
|
| 235 |
+
qkv_bias=qkv_bias,
|
| 236 |
+
attn_drop=attn_drop,
|
| 237 |
+
proj_drop=proj_drop,
|
| 238 |
+
)
|
| 239 |
+
self.norm_z1 = nn.LayerNorm(dim_z)
|
| 240 |
+
self.norm_z2 = nn.LayerNorm(dim_z)
|
| 241 |
+
self.ff = FeedForward(dim_z, dropout=ff_drop)
|
| 242 |
+
|
| 243 |
+
def forward(self, z, x):
|
| 244 |
+
# TODO: do we need to normalize x?
|
| 245 |
+
z = z + self.attn(self.norm_z1(z), self.norm_x(x) if self.norm_x_input else x)
|
| 246 |
+
z = z + self.ff(self.norm_z2(z))
|
| 247 |
+
return z
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
@torch.no_grad()
|
| 251 |
+
def get_triplane_attention_mask(res):
|
| 252 |
+
N = 3 * res * res
|
| 253 |
+
attn_mask = torch.zeros(3, res, res, 3, res, res)
|
| 254 |
+
|
| 255 |
+
i, j = torch.meshgrid(torch.arange(res), torch.arange(res))
|
| 256 |
+
|
| 257 |
+
attn_mask[0, i, j, 1, i, :] = 1.0
|
| 258 |
+
attn_mask[0, i, j, 2, j, :] = 1.0
|
| 259 |
+
attn_mask[1, i, j, 0, i, :] = 1.0
|
| 260 |
+
attn_mask[1, i, j, 2, :, j] = 1.0
|
| 261 |
+
attn_mask[2, i, j, 0, :, i] = 1.0
|
| 262 |
+
attn_mask[2, i, j, 1, :, j] = 1.0
|
| 263 |
+
attn_mask = attn_mask.bool()
|
| 264 |
+
|
| 265 |
+
attn_bias = torch.empty_like(attn_mask, dtype=torch.float)
|
| 266 |
+
attn_bias.masked_fill_(attn_mask, 0.0)
|
| 267 |
+
attn_bias.masked_fill_(~attn_mask, float("-inf"))
|
| 268 |
+
|
| 269 |
+
return attn_bias.reshape(N, N)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class TriplaneAttention(nn.Module):
|
| 273 |
+
def __init__(
|
| 274 |
+
self,
|
| 275 |
+
dim: int,
|
| 276 |
+
resolution: int,
|
| 277 |
+
num_heads: int = 16,
|
| 278 |
+
qkv_bias: bool = False,
|
| 279 |
+
attn_drop: float = 0.0,
|
| 280 |
+
proj_drop: float = 0.0,
|
| 281 |
+
full_attention: bool = False,
|
| 282 |
+
):
|
| 283 |
+
super().__init__()
|
| 284 |
+
self.num_heads = num_heads
|
| 285 |
+
head_dim = dim // num_heads
|
| 286 |
+
self.scale = head_dim**-0.5
|
| 287 |
+
self.wq = nn.Linear(dim, dim, bias=qkv_bias)
|
| 288 |
+
self.wk = nn.Linear(dim, dim, bias=qkv_bias)
|
| 289 |
+
self.wv = nn.Linear(dim, dim, bias=qkv_bias)
|
| 290 |
+
self.attn_drop = attn_drop
|
| 291 |
+
self.proj = nn.Linear(dim, dim)
|
| 292 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 293 |
+
|
| 294 |
+
self.resolution = resolution
|
| 295 |
+
self.full_attention = full_attention
|
| 296 |
+
self.attn_mask = (
|
| 297 |
+
get_triplane_attention_mask(resolution) if not full_attention else None
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
def forward(self, x):
|
| 301 |
+
B, N, C = x.shape
|
| 302 |
+
# [B, N, C] -> [B, N, H, C/H]
|
| 303 |
+
q = self.wq(x).reshape(B, N, self.num_heads, C // self.num_heads)
|
| 304 |
+
k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads)
|
| 305 |
+
v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads)
|
| 306 |
+
|
| 307 |
+
# detokenize the planes
|
| 308 |
+
assert N == self.resolution**2 * 3
|
| 309 |
+
attn_bias = (
|
| 310 |
+
self.attn_mask.to(q)
|
| 311 |
+
.unsqueeze(0)
|
| 312 |
+
.unsqueeze(0)
|
| 313 |
+
.expand(B, self.num_heads, -1, -1)
|
| 314 |
+
if not self.full_attention
|
| 315 |
+
else None
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
# full attention
|
| 319 |
+
x = torch.nn.functional.scaled_dot_product_attention(
|
| 320 |
+
q.permute(0, 2, 1, 3),
|
| 321 |
+
k.permute(0, 2, 1, 3),
|
| 322 |
+
v.permute(0, 2, 1, 3),
|
| 323 |
+
attn_mask=attn_bias,
|
| 324 |
+
dropout_p=self.attn_drop,
|
| 325 |
+
scale=self.scale,
|
| 326 |
+
).permute(0, 2, 1, 3)
|
| 327 |
+
|
| 328 |
+
# [B, N_q, H, C/H] -> [B, N_q, C]
|
| 329 |
+
x = x.reshape(B, N, C)
|
| 330 |
+
x = self.proj(x)
|
| 331 |
+
x = self.proj_drop(x)
|
| 332 |
+
return x
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
class TwoStreamBlock(nn.Module):
|
| 336 |
+
def __init__(
|
| 337 |
+
self,
|
| 338 |
+
dim_latent: int,
|
| 339 |
+
dim_input: int,
|
| 340 |
+
num_basic_blocks: int = 4,
|
| 341 |
+
num_heads: int = 16,
|
| 342 |
+
qkv_bias: bool = False,
|
| 343 |
+
attn_drop: float = 0.0,
|
| 344 |
+
proj_drop: float = 0.0,
|
| 345 |
+
ff_drop: float = 0.0,
|
| 346 |
+
norm_x_input: bool = True,
|
| 347 |
+
dim_cross: Optional[int] = None,
|
| 348 |
+
):
|
| 349 |
+
super().__init__()
|
| 350 |
+
|
| 351 |
+
# Define the fuse block that fuses the input into the latent
|
| 352 |
+
self.fuse_block_in = FuseBlock(
|
| 353 |
+
dim_latent,
|
| 354 |
+
dim_input,
|
| 355 |
+
num_heads=num_heads,
|
| 356 |
+
qkv_bias=qkv_bias,
|
| 357 |
+
attn_drop=attn_drop,
|
| 358 |
+
proj_drop=proj_drop,
|
| 359 |
+
ff_drop=ff_drop,
|
| 360 |
+
norm_x_input=norm_x_input,
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
# Define the transformer blocks that process the latent
|
| 364 |
+
self.transformer_block = nn.ModuleList(
|
| 365 |
+
[
|
| 366 |
+
BasicBlock(
|
| 367 |
+
dim_latent,
|
| 368 |
+
kv_dim=dim_cross,
|
| 369 |
+
num_heads=num_heads,
|
| 370 |
+
qkv_bias=qkv_bias,
|
| 371 |
+
proj_drop=proj_drop,
|
| 372 |
+
ff_drop=ff_drop,
|
| 373 |
+
)
|
| 374 |
+
for _ in range(num_basic_blocks)
|
| 375 |
+
]
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
# Define the fuse block that fuses the latent into the input
|
| 379 |
+
self.fuse_block_out = FuseBlock(
|
| 380 |
+
dim_input,
|
| 381 |
+
dim_latent,
|
| 382 |
+
num_heads=num_heads,
|
| 383 |
+
qkv_bias=qkv_bias,
|
| 384 |
+
attn_drop=attn_drop,
|
| 385 |
+
proj_drop=proj_drop,
|
| 386 |
+
ff_drop=ff_drop,
|
| 387 |
+
norm_x_input=norm_x_input,
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
def forward(self, latent, input, cross_input):
|
| 391 |
+
latent = self.fuse_block_in(latent, input)
|
| 392 |
+
for block in self.transformer_block:
|
| 393 |
+
latent = block(latent, cross_input)
|
| 394 |
+
input = self.fuse_block_out(input, latent)
|
| 395 |
+
return latent, input
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
class TwoStreamInterleaveTransformer(BaseModule):
|
| 399 |
+
@dataclass
|
| 400 |
+
class Config(BaseModule.Config):
|
| 401 |
+
num_attention_heads: int = 16
|
| 402 |
+
attention_head_dim: int = 64
|
| 403 |
+
raw_triplane_channels: int = 1024
|
| 404 |
+
triplane_channels: int = 1024
|
| 405 |
+
raw_image_channels: int = 1024
|
| 406 |
+
num_latents: int = 1792
|
| 407 |
+
num_blocks: int = 4
|
| 408 |
+
num_basic_blocks: int = 3
|
| 409 |
+
dropout: float = 0.0
|
| 410 |
+
latent_init_std: float = 0.02
|
| 411 |
+
norm_num_groups: int = 32
|
| 412 |
+
attention_bias: bool = False
|
| 413 |
+
norm_x_input: bool = False
|
| 414 |
+
cross_attention_dim: int = 1024
|
| 415 |
+
mix_latent: bool = True
|
| 416 |
+
|
| 417 |
+
cfg: Config
|
| 418 |
+
|
| 419 |
+
def configure(self) -> None:
|
| 420 |
+
self.mix_latent = self.cfg.mix_latent
|
| 421 |
+
|
| 422 |
+
# Define the dimensions
|
| 423 |
+
self.num_attention_heads = self.cfg.num_attention_heads
|
| 424 |
+
self.attention_head_dim = self.cfg.attention_head_dim
|
| 425 |
+
self.num_latents = self.cfg.num_latents
|
| 426 |
+
self.latent_dim = self.num_attention_heads * self.attention_head_dim
|
| 427 |
+
|
| 428 |
+
# Define input layers
|
| 429 |
+
if self.cfg.norm_num_groups > 0:
|
| 430 |
+
self.norm_triplane = torch.nn.GroupNorm(
|
| 431 |
+
num_groups=self.cfg.norm_num_groups,
|
| 432 |
+
num_channels=self.cfg.raw_triplane_channels,
|
| 433 |
+
eps=1e-6,
|
| 434 |
+
affine=True,
|
| 435 |
+
)
|
| 436 |
+
else:
|
| 437 |
+
self.norm_triplane = nn.LayerNorm(self.cfg.raw_triplane_channels)
|
| 438 |
+
self.proj_triplane = nn.Linear(
|
| 439 |
+
self.cfg.raw_triplane_channels, self.cfg.triplane_channels
|
| 440 |
+
)
|
| 441 |
+
if self.mix_latent:
|
| 442 |
+
self.norm_image = nn.LayerNorm(self.cfg.raw_image_channels)
|
| 443 |
+
self.proj_image = nn.Linear(self.cfg.raw_image_channels, self.latent_dim)
|
| 444 |
+
self.norm_latent = nn.LayerNorm(self.latent_dim)
|
| 445 |
+
self.proj_latent = nn.Linear(self.latent_dim, self.latent_dim)
|
| 446 |
+
|
| 447 |
+
# Define the latents
|
| 448 |
+
self.latent_init = nn.Parameter(
|
| 449 |
+
torch.zeros(1, self.num_latents, self.latent_dim)
|
| 450 |
+
)
|
| 451 |
+
nn.init.normal_(self.latent_init, std=self.cfg.latent_init_std)
|
| 452 |
+
|
| 453 |
+
# Define the transformer blocks
|
| 454 |
+
self.main_blocks = nn.ModuleList(
|
| 455 |
+
[
|
| 456 |
+
TwoStreamBlock(
|
| 457 |
+
self.latent_dim,
|
| 458 |
+
self.cfg.triplane_channels,
|
| 459 |
+
num_basic_blocks=self.cfg.num_basic_blocks,
|
| 460 |
+
num_heads=self.num_attention_heads,
|
| 461 |
+
qkv_bias=self.cfg.attention_bias,
|
| 462 |
+
proj_drop=self.cfg.dropout,
|
| 463 |
+
ff_drop=self.cfg.dropout,
|
| 464 |
+
norm_x_input=self.cfg.norm_x_input,
|
| 465 |
+
dim_cross=self.cfg.cross_attention_dim,
|
| 466 |
+
)
|
| 467 |
+
for _ in range(self.cfg.num_blocks)
|
| 468 |
+
]
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
# 4. Define output layers
|
| 472 |
+
self.proj_out = nn.Linear(
|
| 473 |
+
self.cfg.triplane_channels, self.cfg.raw_triplane_channels
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
def forward(self, hidden_states, encoder_hidden_states, **kwargs):
|
| 477 |
+
# hidden_states: [B, triplane_dim, N_triplane] is triplane tokens
|
| 478 |
+
# encoder_hidden_states: [B, N_image, image_dim] is the image tokens
|
| 479 |
+
if isinstance(self.norm_triplane, nn.GroupNorm):
|
| 480 |
+
triplane_tokens = self.norm_triplane(hidden_states)
|
| 481 |
+
triplane_tokens = triplane_tokens.permute(
|
| 482 |
+
0, 2, 1
|
| 483 |
+
) # [B, N_triplane, triplane_dim]
|
| 484 |
+
elif isinstance(self.norm_triplane, nn.LayerNorm):
|
| 485 |
+
triplane_tokens = self.norm_triplane(hidden_states.permute(0, 2, 1))
|
| 486 |
+
else:
|
| 487 |
+
raise ValueError("Unknown normalization layer")
|
| 488 |
+
triplane_tokens = self.proj_triplane(triplane_tokens)
|
| 489 |
+
if self.mix_latent:
|
| 490 |
+
image_tokens = self.norm_image(
|
| 491 |
+
encoder_hidden_states
|
| 492 |
+
) # [B, N_image, image_dim]
|
| 493 |
+
image_tokens = self.proj_image(image_tokens)
|
| 494 |
+
init_latents = self.latent_init.expand(
|
| 495 |
+
hidden_states.shape[0], -1, -1
|
| 496 |
+
) # [B, N_latent_init, latent_dim]
|
| 497 |
+
init_latents = self.norm_latent(init_latents)
|
| 498 |
+
init_latents = self.proj_latent(init_latents)
|
| 499 |
+
if self.mix_latent:
|
| 500 |
+
latent_tokens = torch.cat(
|
| 501 |
+
[image_tokens, init_latents], dim=1
|
| 502 |
+
) # [B, N_latent, latent_dim]
|
| 503 |
+
else:
|
| 504 |
+
latent_tokens = init_latents
|
| 505 |
+
|
| 506 |
+
# forward the main blocks
|
| 507 |
+
for block in self.main_blocks:
|
| 508 |
+
latent_tokens, triplane_tokens = block(
|
| 509 |
+
latent_tokens, triplane_tokens, encoder_hidden_states
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
# project the triplane tokens back to the original dimension
|
| 513 |
+
triplane_tokens = self.proj_out(triplane_tokens).permute(0, 2, 1).contiguous()
|
| 514 |
+
triplane_tokens = triplane_tokens + hidden_states
|
| 515 |
+
return triplane_tokens
|
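A small-resolution sketch (not part of the diff) of the cross-plane attention bias built by get_triplane_attention_mask: with full_attention=False each triplane token may only attend to the row or column of the other two planes that shares one of its spatial coordinates, so every query sees exactly 2 * res keys and all other entries are -inf:

import torch
from sf3d.models.transformers.backbone import get_triplane_attention_mask

res = 4
bias = get_triplane_attention_mask(res)     # [3*res*res, 3*res*res], entries in {0, -inf}
allowed = torch.isfinite(bias).sum(dim=-1)  # attendable keys per query token
print(bias.shape, allowed.unique())         # torch.Size([48, 48]) tensor([8])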
sf3d/models/utils.py
ADDED
|
@@ -0,0 +1,236 @@
| 1 |
+
import dataclasses
|
| 2 |
+
import importlib
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import PIL
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from jaxtyping import Float, Int, Num
|
| 12 |
+
from omegaconf import DictConfig, OmegaConf
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseModule(nn.Module):
|
| 17 |
+
@dataclass
|
| 18 |
+
class Config:
|
| 19 |
+
pass
|
| 20 |
+
|
| 21 |
+
cfg: Config # add this to every subclass of BaseModule to enable static type checking
|
| 22 |
+
|
| 23 |
+
def __init__(
|
| 24 |
+
self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
|
| 25 |
+
) -> None:
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.cfg = parse_structured(self.Config, cfg)
|
| 28 |
+
self.configure(*args, **kwargs)
|
| 29 |
+
|
| 30 |
+
def configure(self, *args, **kwargs) -> None:
|
| 31 |
+
raise NotImplementedError
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def find_class(cls_string):
|
| 35 |
+
module_string = ".".join(cls_string.split(".")[:-1])
|
| 36 |
+
cls_name = cls_string.split(".")[-1]
|
| 37 |
+
module = importlib.import_module(module_string, package=None)
|
| 38 |
+
cls = getattr(module, cls_name)
|
| 39 |
+
return cls
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
|
| 43 |
+
# Check if cfg.keys are in fields
|
| 44 |
+
cfg_ = cfg.copy()
|
| 45 |
+
keys = list(cfg_.keys())
|
| 46 |
+
|
| 47 |
+
field_names = {f.name for f in dataclasses.fields(fields)}
|
| 48 |
+
for key in keys:
|
| 49 |
+
# This is helpful when swapping out modules from CLI
|
| 50 |
+
if key not in field_names:
|
| 51 |
+
print(f"Ignoring {key} as it's not supported by {fields}")
|
| 52 |
+
cfg_.pop(key)
|
| 53 |
+
scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg_)
|
| 54 |
+
return scfg
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
EPS_DTYPE = {
|
| 58 |
+
torch.float16: 1e-4,
|
| 59 |
+
torch.bfloat16: 1e-4,
|
| 60 |
+
torch.float32: 1e-7,
|
| 61 |
+
torch.float64: 1e-8,
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def dot(x, y, dim=-1):
|
| 66 |
+
return torch.sum(x * y, dim, keepdim=True)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def reflect(x, n):
|
| 70 |
+
return x - 2 * dot(x, n) * n
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def normalize(x, dim=-1, eps=None):
|
| 74 |
+
if eps is None:
|
| 75 |
+
eps = EPS_DTYPE[x.dtype]
|
| 76 |
+
return F.normalize(x, dim=dim, p=2, eps=eps)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def scale_tensor(
|
| 83 |
+
dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
|
| 84 |
+
):
|
| 85 |
+
if inp_scale is None:
|
| 86 |
+
inp_scale = (0, 1)
|
| 87 |
+
if tgt_scale is None:
|
| 88 |
+
tgt_scale = (0, 1)
|
| 89 |
+
if isinstance(tgt_scale, Tensor):
|
| 90 |
+
assert dat.shape[-1] == tgt_scale.shape[-1]
|
| 91 |
+
dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
|
| 92 |
+
dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
|
| 93 |
+
return dat
|
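`scale_tensor` is a simple affine remap from one value range to another; passing `None` for either range defaults to `(0, 1)`. A small usage sketch, assuming the function is importable as `sf3d.models.utils.scale_tensor`; the target range here is illustrative, not taken from the configs.

```python
import torch

from sf3d.models.utils import scale_tensor

# Map grid vertices given in [0, 1] into a bounding box [-0.87, 0.87]^3,
# mirroring how isosurface grid points are rescaled elsewhere in the code.
grid_vertices = torch.rand(8, 3)   # values in [0, 1)
bbox = (-0.87, 0.87)               # illustrative radius
scaled = scale_tensor(grid_vertices, (0, 1), bbox)
assert scaled.min() >= -0.87 and scaled.max() <= 0.87
```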
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def dilate_fill(img, mask, iterations=10):
|
| 97 |
+
oldMask = mask.float()
|
| 98 |
+
oldImg = img
|
| 99 |
+
|
| 100 |
+
mask_kernel = torch.ones(
|
| 101 |
+
(1, 1, 3, 3),
|
| 102 |
+
dtype=oldMask.dtype,
|
| 103 |
+
device=oldMask.device,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
for i in range(iterations):
|
| 107 |
+
newMask = torch.nn.functional.max_pool2d(oldMask, 3, 1, 1)
|
| 108 |
+
|
| 109 |
+
# Fill the extension with mean color of old valid regions
|
| 110 |
+
img_unfold = F.unfold(oldImg, (3, 3)).view(1, 3, 3 * 3, -1)
|
| 111 |
+
mask_unfold = F.unfold(oldMask, (3, 3)).view(1, 1, 3 * 3, -1)
|
| 112 |
+
new_mask_unfold = F.unfold(newMask, (3, 3)).view(1, 1, 3 * 3, -1)
|
| 113 |
+
|
| 114 |
+
# Average color of the valid region
|
| 115 |
+
mean_color = (img_unfold.sum(dim=2) / mask_unfold.sum(dim=2).clip(1)).unsqueeze(
|
| 116 |
+
2
|
| 117 |
+
)
|
| 118 |
+
# Extend it to the new region
|
| 119 |
+
fill_color = (mean_color * new_mask_unfold).view(1, 3 * 3 * 3, -1)
|
| 120 |
+
|
| 121 |
+
mask_conv = F.conv2d(
|
| 122 |
+
newMask, mask_kernel, padding=1
|
| 123 |
+
) # Get the sum for each kernel patch
|
| 124 |
+
newImg = F.fold(
|
| 125 |
+
fill_color, (img.shape[-2], img.shape[-1]), (3, 3)
|
| 126 |
+
) / mask_conv.clamp(1)
|
| 127 |
+
|
| 128 |
+
diffMask = newMask - oldMask
|
| 129 |
+
|
| 130 |
+
oldMask = newMask
|
| 131 |
+
oldImg = torch.lerp(oldImg, newImg, diffMask)
|
| 132 |
+
|
| 133 |
+
return oldImg
|
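`dilate_fill` grows the valid region of a texture outward, filling newly covered pixels with the local mean color of already-valid neighbours; it expects an NCHW image and a single-channel mask. A hedged usage sketch (shapes and iteration count are illustrative):

```python
import torch

from sf3d.models.utils import dilate_fill

# A 3-channel 256x256 texture with a float 0/1 validity mask.
img = torch.rand(1, 3, 256, 256)
mask = (torch.rand(1, 1, 256, 256) > 0.5).float()

# Each iteration grows the valid region by one pixel, so this pushes
# valid colors ~8 px outwards to hide seams at UV island borders.
padded = dilate_fill(img, mask, iterations=8)
print(padded.shape)  # torch.Size([1, 3, 256, 256])
```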
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def float32_to_uint8_np(
|
| 137 |
+
x: Float[np.ndarray, "*B H W C"],
|
| 138 |
+
dither: bool = True,
|
| 139 |
+
dither_mask: Optional[Float[np.ndarray, "*B H W C"]] = None,
|
| 140 |
+
dither_strength: float = 1.0,
|
| 141 |
+
) -> Int[np.ndarray, "*B H W C"]:
|
| 142 |
+
if dither:
|
| 143 |
+
dither = (
|
| 144 |
+
dither_strength * np.random.rand(*x[..., :1].shape).astype(np.float32) - 0.5
|
| 145 |
+
)
|
| 146 |
+
if dither_mask is not None:
|
| 147 |
+
dither = dither * dither_mask
|
| 148 |
+
return np.clip(np.floor((256.0 * x + dither)), 0, 255).astype(np.uint8)
|
| 149 |
+
    return np.clip(np.floor((256.0 * x)), 0, 255).astype(np.uint8)
|
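The dither branch above adds sub-LSB noise before the float-to-uint8 floor so that smooth gradients do not quantize into visible bands. A small, purely synthetic illustration of the effect:

```python
import numpy as np

rng = np.random.default_rng(0)

# A slow gradient quantizes to long constant runs ("bands") with a plain floor.
x = np.linspace(0.0, 0.05, 512, dtype=np.float32)

hard = np.clip(np.floor(256.0 * x), 0, 255).astype(np.uint8)
dither = rng.random(x.shape, dtype=np.float32) - 0.5
soft = np.clip(np.floor(256.0 * x + dither), 0, 255).astype(np.uint8)

# `hard` steps 0, 1, 2, ... in long runs; `soft` flips between neighbouring
# codes near each band edge, which averages out to the true gradient.
print(hard[:16])
print(soft[:16])
```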
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def convert_data(data):
|
| 153 |
+
if data is None:
|
| 154 |
+
return None
|
| 155 |
+
elif isinstance(data, np.ndarray):
|
| 156 |
+
return data
|
| 157 |
+
elif isinstance(data, torch.Tensor):
|
| 158 |
+
if data.dtype in [torch.float16, torch.bfloat16]:
|
| 159 |
+
data = data.float()
|
| 160 |
+
return data.detach().cpu().numpy()
|
| 161 |
+
elif isinstance(data, list):
|
| 162 |
+
return [convert_data(d) for d in data]
|
| 163 |
+
elif isinstance(data, dict):
|
| 164 |
+
return {k: convert_data(v) for k, v in data.items()}
|
| 165 |
+
else:
|
| 166 |
+
raise TypeError(
|
| 167 |
+
"Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting",
|
| 168 |
+
type(data),
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class ImageProcessor:
|
| 173 |
+
def convert_and_resize(
|
| 174 |
+
self,
|
| 175 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
| 176 |
+
size: int,
|
| 177 |
+
):
|
| 178 |
+
if isinstance(image, PIL.Image.Image):
|
| 179 |
+
image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
|
| 180 |
+
elif isinstance(image, np.ndarray):
|
| 181 |
+
if image.dtype == np.uint8:
|
| 182 |
+
image = torch.from_numpy(image.astype(np.float32) / 255.0)
|
| 183 |
+
else:
|
| 184 |
+
image = torch.from_numpy(image)
|
| 185 |
+
elif isinstance(image, torch.Tensor):
|
| 186 |
+
pass
|
| 187 |
+
|
| 188 |
+
batched = image.ndim == 4
|
| 189 |
+
|
| 190 |
+
if not batched:
|
| 191 |
+
image = image[None, ...]
|
| 192 |
+
image = F.interpolate(
|
| 193 |
+
image.permute(0, 3, 1, 2),
|
| 194 |
+
(size, size),
|
| 195 |
+
mode="bilinear",
|
| 196 |
+
align_corners=False,
|
| 197 |
+
antialias=True,
|
| 198 |
+
).permute(0, 2, 3, 1)
|
| 199 |
+
if not batched:
|
| 200 |
+
image = image[0]
|
| 201 |
+
return image
|
| 202 |
+
|
| 203 |
+
def __call__(
|
| 204 |
+
self,
|
| 205 |
+
image: Union[
|
| 206 |
+
PIL.Image.Image,
|
| 207 |
+
np.ndarray,
|
| 208 |
+
torch.FloatTensor,
|
| 209 |
+
List[PIL.Image.Image],
|
| 210 |
+
List[np.ndarray],
|
| 211 |
+
List[torch.FloatTensor],
|
| 212 |
+
],
|
| 213 |
+
size: int,
|
| 214 |
+
) -> Any:
|
| 215 |
+
if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
|
| 216 |
+
image = self.convert_and_resize(image, size)
|
| 217 |
+
else:
|
| 218 |
+
if not isinstance(image, list):
|
| 219 |
+
image = [image]
|
| 220 |
+
image = [self.convert_and_resize(im, size) for im in image]
|
| 221 |
+
image = torch.stack(image, dim=0)
|
| 222 |
+
return image
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def get_intrinsic_from_fov(fov, H, W, bs=-1):
|
| 226 |
+
focal_length = 0.5 * H / np.tan(0.5 * fov)
|
| 227 |
+
intrinsic = np.identity(3, dtype=np.float32)
|
| 228 |
+
intrinsic[0, 0] = focal_length
|
| 229 |
+
intrinsic[1, 1] = focal_length
|
| 230 |
+
intrinsic[0, 2] = W / 2.0
|
| 231 |
+
intrinsic[1, 2] = H / 2.0
|
| 232 |
+
|
| 233 |
+
if bs > 0:
|
| 234 |
+
intrinsic = intrinsic[None].repeat(bs, axis=0)
|
| 235 |
+
|
| 236 |
+
return torch.from_numpy(intrinsic)
|
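`BaseModule` pairs an OmegaConf-structured `Config` dataclass with a `configure()` hook, and `parse_structured` merges a plain dict (or `DictConfig`) into that schema, dropping unknown keys with a warning. A minimal sketch of how a subclass would be defined and instantiated; the `Toy` module and its fields are hypothetical, not part of SF3D.

```python
from dataclasses import dataclass

import torch.nn as nn

from sf3d.models.utils import BaseModule


class Toy(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        in_dim: int = 8
        out_dim: int = 4

    cfg: Config

    def configure(self) -> None:
        # Build submodules from the parsed config.
        self.proj = nn.Linear(self.cfg.in_dim, self.cfg.out_dim)


# Unknown keys are reported and ignored by parse_structured.
toy = Toy({"in_dim": 16, "unused_key": 1})
print(toy.cfg.out_dim)  # 4 (default retained)
```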
sf3d/system.py
ADDED
|
@@ -0,0 +1,534 @@
| 1 |
+
import os
|
| 2 |
+
from contextlib import nullcontext
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Any, List, Literal, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import trimesh
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
from huggingface_hub import hf_hub_download
|
| 12 |
+
from jaxtyping import Float
|
| 13 |
+
from omegaconf import OmegaConf
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from safetensors.torch import load_model
|
| 16 |
+
from torch import Tensor
|
| 17 |
+
|
| 18 |
+
from sf3d.models.isosurface import MarchingTetrahedraHelper
|
| 19 |
+
from sf3d.models.mesh import Mesh
|
| 20 |
+
from sf3d.models.utils import (
|
| 21 |
+
BaseModule,
|
| 22 |
+
ImageProcessor,
|
| 23 |
+
convert_data,
|
| 24 |
+
dilate_fill,
|
| 25 |
+
find_class,
|
| 26 |
+
float32_to_uint8_np,
|
| 27 |
+
normalize,
|
| 28 |
+
scale_tensor,
|
| 29 |
+
)
|
| 30 |
+
from sf3d.utils import create_intrinsic_from_fov_deg, default_cond_c2w, get_device
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
from texture_baker import TextureBaker
|
| 34 |
+
except ImportError:
|
| 35 |
+
import logging
|
| 36 |
+
|
| 37 |
+
logging.warning(
|
| 38 |
+
"Could not import texture_baker. Please install it via `pip install texture-baker/`"
|
| 39 |
+
)
|
| 40 |
+
# Exit early to avoid further errors
|
| 41 |
+
raise ImportError("texture_baker not found")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class SF3D(BaseModule):
|
| 45 |
+
@dataclass
|
| 46 |
+
class Config(BaseModule.Config):
|
| 47 |
+
cond_image_size: int
|
| 48 |
+
isosurface_resolution: int
|
| 49 |
+
isosurface_threshold: float = 10.0
|
| 50 |
+
radius: float = 1.0
|
| 51 |
+
background_color: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5])
|
| 52 |
+
default_fovy_deg: float = 40.0
|
| 53 |
+
default_distance: float = 1.6
|
| 54 |
+
|
| 55 |
+
camera_embedder_cls: str = ""
|
| 56 |
+
camera_embedder: dict = field(default_factory=dict)
|
| 57 |
+
|
| 58 |
+
image_tokenizer_cls: str = ""
|
| 59 |
+
image_tokenizer: dict = field(default_factory=dict)
|
| 60 |
+
|
| 61 |
+
tokenizer_cls: str = ""
|
| 62 |
+
tokenizer: dict = field(default_factory=dict)
|
| 63 |
+
|
| 64 |
+
backbone_cls: str = ""
|
| 65 |
+
backbone: dict = field(default_factory=dict)
|
| 66 |
+
|
| 67 |
+
post_processor_cls: str = ""
|
| 68 |
+
post_processor: dict = field(default_factory=dict)
|
| 69 |
+
|
| 70 |
+
decoder_cls: str = ""
|
| 71 |
+
decoder: dict = field(default_factory=dict)
|
| 72 |
+
|
| 73 |
+
image_estimator_cls: str = ""
|
| 74 |
+
image_estimator: dict = field(default_factory=dict)
|
| 75 |
+
|
| 76 |
+
global_estimator_cls: str = ""
|
| 77 |
+
global_estimator: dict = field(default_factory=dict)
|
| 78 |
+
|
| 79 |
+
cfg: Config
|
| 80 |
+
|
| 81 |
+
@classmethod
|
| 82 |
+
def from_pretrained(
|
| 83 |
+
cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
|
| 84 |
+
):
|
| 85 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
| 86 |
+
config_path = os.path.join(pretrained_model_name_or_path, config_name)
|
| 87 |
+
weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
|
| 88 |
+
else:
|
| 89 |
+
config_path = hf_hub_download(
|
| 90 |
+
repo_id=pretrained_model_name_or_path, filename=config_name
|
| 91 |
+
)
|
| 92 |
+
weight_path = hf_hub_download(
|
| 93 |
+
repo_id=pretrained_model_name_or_path, filename=weight_name
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
cfg = OmegaConf.load(config_path)
|
| 97 |
+
OmegaConf.resolve(cfg)
|
| 98 |
+
model = cls(cfg)
|
| 99 |
+
load_model(model, weight_path)
|
| 100 |
+
return model
|
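`from_pretrained` accepts either a local checkpoint directory or a Hugging Face repo id and fetches the YAML config plus safetensors weights before instantiating the model. A hedged loading sketch; the repo id and file names below are the commonly used defaults for SF3D, but they are assumptions here and should be checked against the actual model card.

```python
import torch

from sf3d.system import SF3D

model = SF3D.from_pretrained(
    "stabilityai/stable-fast-3d",     # assumed repo id
    config_name="config.yaml",        # assumed config file name
    weight_name="model.safetensors",  # assumed weights file name
)
model = model.eval().to("cuda" if torch.cuda.is_available() else "cpu")
print(model.device)
```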
| 101 |
+
|
| 102 |
+
@property
|
| 103 |
+
def device(self):
|
| 104 |
+
return next(self.parameters()).device
|
| 105 |
+
|
| 106 |
+
def configure(self):
|
| 107 |
+
self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
|
| 108 |
+
self.cfg.image_tokenizer
|
| 109 |
+
)
|
| 110 |
+
self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
|
| 111 |
+
self.camera_embedder = find_class(self.cfg.camera_embedder_cls)(
|
| 112 |
+
self.cfg.camera_embedder
|
| 113 |
+
)
|
| 114 |
+
self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
|
| 115 |
+
self.post_processor = find_class(self.cfg.post_processor_cls)(
|
| 116 |
+
self.cfg.post_processor
|
| 117 |
+
)
|
| 118 |
+
self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
|
| 119 |
+
self.image_estimator = find_class(self.cfg.image_estimator_cls)(
|
| 120 |
+
self.cfg.image_estimator
|
| 121 |
+
)
|
| 122 |
+
self.global_estimator = find_class(self.cfg.global_estimator_cls)(
|
| 123 |
+
self.cfg.global_estimator
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
self.bbox: Float[Tensor, "2 3"]
|
| 127 |
+
self.register_buffer(
|
| 128 |
+
"bbox",
|
| 129 |
+
torch.as_tensor(
|
| 130 |
+
[
|
| 131 |
+
[-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
|
| 132 |
+
[self.cfg.radius, self.cfg.radius, self.cfg.radius],
|
| 133 |
+
],
|
| 134 |
+
dtype=torch.float32,
|
| 135 |
+
),
|
| 136 |
+
)
|
| 137 |
+
self.isosurface_helper = MarchingTetrahedraHelper(
|
| 138 |
+
self.cfg.isosurface_resolution,
|
| 139 |
+
os.path.join(
|
| 140 |
+
os.path.dirname(__file__),
|
| 141 |
+
"..",
|
| 142 |
+
"load",
|
| 143 |
+
"tets",
|
| 144 |
+
f"{self.cfg.isosurface_resolution}_tets.npz",
|
| 145 |
+
),
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
self.baker = TextureBaker()
|
| 149 |
+
self.image_processor = ImageProcessor()
|
| 150 |
+
|
| 151 |
+
def triplane_to_meshes(
|
| 152 |
+
self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"]
|
| 153 |
+
) -> list[Mesh]:
|
| 154 |
+
meshes = []
|
| 155 |
+
for i in range(triplanes.shape[0]):
|
| 156 |
+
triplane = triplanes[i]
|
| 157 |
+
grid_vertices = scale_tensor(
|
| 158 |
+
self.isosurface_helper.grid_vertices.to(triplanes.device),
|
| 159 |
+
self.isosurface_helper.points_range,
|
| 160 |
+
self.bbox,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
values = self.query_triplane(grid_vertices, triplane)
|
| 164 |
+
decoded = self.decoder(values, include=["vertex_offset", "density"])
|
| 165 |
+
sdf = decoded["density"] - self.cfg.isosurface_threshold
|
| 166 |
+
|
| 167 |
+
deform = decoded["vertex_offset"].squeeze(0)
|
| 168 |
+
|
| 169 |
+
mesh: Mesh = self.isosurface_helper(
|
| 170 |
+
sdf.view(-1, 1), deform.view(-1, 3) if deform is not None else None
|
| 171 |
+
)
|
| 172 |
+
mesh.v_pos = scale_tensor(
|
| 173 |
+
mesh.v_pos, self.isosurface_helper.points_range, self.bbox
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
meshes.append(mesh)
|
| 177 |
+
|
| 178 |
+
return meshes
|
| 179 |
+
|
| 180 |
+
def query_triplane(
|
| 181 |
+
self,
|
| 182 |
+
positions: Float[Tensor, "*B N 3"],
|
| 183 |
+
triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
|
| 184 |
+
) -> Float[Tensor, "*B N F"]:
|
| 185 |
+
batched = positions.ndim == 3
|
| 186 |
+
if not batched:
|
| 187 |
+
# no batch dimension
|
| 188 |
+
triplanes = triplanes[None, ...]
|
| 189 |
+
positions = positions[None, ...]
|
| 190 |
+
assert triplanes.ndim == 5 and positions.ndim == 3
|
| 191 |
+
|
| 192 |
+
positions = scale_tensor(
|
| 193 |
+
positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
|
| 197 |
+
(positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
|
| 198 |
+
dim=-3,
|
| 199 |
+
).to(triplanes.dtype)
|
| 200 |
+
out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
|
| 201 |
+
rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3).float(),
|
| 202 |
+
rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3).float(),
|
| 203 |
+
align_corners=True,
|
| 204 |
+
mode="bilinear",
|
| 205 |
+
)
|
| 206 |
+
out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)
|
| 207 |
+
|
| 208 |
+
return out
|
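`query_triplane` projects each 3D query point onto the XY, XZ and YZ planes, samples the corresponding feature plane bilinearly, and concatenates the three samples per point. A shape-level sketch reusing the `model` instance loaded in the earlier sketch; the plane dimensions are made up for illustration.

```python
import torch

# Hypothetical shapes: 1 scene, 3 planes, 40 channels, 96x96 resolution.
triplanes = torch.randn(1, 3, 40, 96, 96)
positions = torch.rand(1, 1000, 3) * 2.0 - 1.0  # points inside the [-radius, radius] cube

features = model.query_triplane(positions, triplanes)
print(features.shape)  # torch.Size([1, 1000, 120])  -> 3 planes x 40 channels
```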
| 209 |
+
|
| 210 |
+
def get_scene_codes(self, batch) -> Float[Tensor, "B 3 C H W"]:
|
| 211 |
+
# if batch[rgb_cond] is only one view, add a view dimension
|
| 212 |
+
if len(batch["rgb_cond"].shape) == 4:
|
| 213 |
+
batch["rgb_cond"] = batch["rgb_cond"].unsqueeze(1)
|
| 214 |
+
batch["mask_cond"] = batch["mask_cond"].unsqueeze(1)
|
| 215 |
+
batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1)
|
| 216 |
+
batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1)
|
| 217 |
+
batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1)
|
| 218 |
+
|
| 219 |
+
batch_size, n_input_views = batch["rgb_cond"].shape[:2]
|
| 220 |
+
|
| 221 |
+
camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
|
| 222 |
+
camera_embeds = self.camera_embedder(**batch)
|
| 223 |
+
|
| 224 |
+
input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.image_tokenizer(
|
| 225 |
+
rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"),
|
| 226 |
+
modulation_cond=camera_embeds,
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
input_image_tokens = rearrange(
|
| 230 |
+
input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)
|
| 234 |
+
|
| 235 |
+
tokens = self.backbone(
|
| 236 |
+
tokens,
|
| 237 |
+
encoder_hidden_states=input_image_tokens,
|
| 238 |
+
modulation_cond=None,
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
direct_codes = self.tokenizer.detokenize(tokens)
|
| 242 |
+
scene_codes = self.post_processor(direct_codes)
|
| 243 |
+
return scene_codes, direct_codes
|
| 244 |
+
|
| 245 |
+
def run_image(
|
| 246 |
+
self,
|
| 247 |
+
image: Union[Image.Image, List[Image.Image]],
|
| 248 |
+
bake_resolution: int,
|
| 249 |
+
remesh: Literal["none", "triangle", "quad"] = "none",
|
| 250 |
+
vertex_count: int = -1,
|
| 251 |
+
estimate_illumination: bool = False,
|
| 252 |
+
) -> Tuple[Union[trimesh.Trimesh, List[trimesh.Trimesh]], dict[str, Any]]:
|
| 253 |
+
if isinstance(image, list):
|
| 254 |
+
rgb_cond = []
|
| 255 |
+
mask_cond = []
|
| 256 |
+
for img in image:
|
| 257 |
+
mask, rgb = self.prepare_image(img)
|
| 258 |
+
mask_cond.append(mask)
|
| 259 |
+
rgb_cond.append(rgb)
|
| 260 |
+
rgb_cond = torch.stack(rgb_cond, 0)
|
| 261 |
+
mask_cond = torch.stack(mask_cond, 0)
|
| 262 |
+
batch_size = rgb_cond.shape[0]
|
| 263 |
+
else:
|
| 264 |
+
mask_cond, rgb_cond = self.prepare_image(image)
|
| 265 |
+
batch_size = 1
|
| 266 |
+
|
| 267 |
+
c2w_cond = default_cond_c2w(self.cfg.default_distance).to(self.device)
|
| 268 |
+
intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_deg(
|
| 269 |
+
self.cfg.default_fovy_deg,
|
| 270 |
+
self.cfg.cond_image_size,
|
| 271 |
+
self.cfg.cond_image_size,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
batch = {
|
| 275 |
+
"rgb_cond": rgb_cond,
|
| 276 |
+
"mask_cond": mask_cond,
|
| 277 |
+
"c2w_cond": c2w_cond.view(1, 1, 4, 4).repeat(batch_size, 1, 1, 1),
|
| 278 |
+
"intrinsic_cond": intrinsic.to(self.device)
|
| 279 |
+
.view(1, 1, 3, 3)
|
| 280 |
+
.repeat(batch_size, 1, 1, 1),
|
| 281 |
+
"intrinsic_normed_cond": intrinsic_normed_cond.to(self.device)
|
| 282 |
+
.view(1, 1, 3, 3)
|
| 283 |
+
.repeat(batch_size, 1, 1, 1),
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
meshes, global_dict = self.generate_mesh(
|
| 287 |
+
batch, bake_resolution, remesh, vertex_count, estimate_illumination
|
| 288 |
+
)
|
| 289 |
+
if batch_size == 1:
|
| 290 |
+
return meshes[0], global_dict
|
| 291 |
+
else:
|
| 292 |
+
return meshes, global_dict
|
| 293 |
+
|
| 294 |
+
def prepare_image(self, image):
|
| 295 |
+
if image.mode != "RGBA":
|
| 296 |
+
raise ValueError("Image must be in RGBA mode")
|
| 297 |
+
img_cond = (
|
| 298 |
+
torch.from_numpy(
|
| 299 |
+
np.asarray(
|
| 300 |
+
image.resize((self.cfg.cond_image_size, self.cfg.cond_image_size))
|
| 301 |
+
).astype(np.float32)
|
| 302 |
+
/ 255.0
|
| 303 |
+
)
|
| 304 |
+
.float()
|
| 305 |
+
.clip(0, 1)
|
| 306 |
+
.to(self.device)
|
| 307 |
+
)
|
| 308 |
+
mask_cond = img_cond[:, :, -1:]
|
| 309 |
+
rgb_cond = torch.lerp(
|
| 310 |
+
torch.tensor(self.cfg.background_color, device=self.device)[None, None, :],
|
| 311 |
+
img_cond[:, :, :3],
|
| 312 |
+
mask_cond,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
return mask_cond, rgb_cond
|
| 316 |
+
|
| 317 |
+
def generate_mesh(
|
| 318 |
+
self,
|
| 319 |
+
batch,
|
| 320 |
+
bake_resolution: int,
|
| 321 |
+
remesh: Literal["none", "triangle", "quad"] = "none",
|
| 322 |
+
vertex_count: int = -1,
|
| 323 |
+
estimate_illumination: bool = False,
|
| 324 |
+
) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
|
| 325 |
+
batch["rgb_cond"] = self.image_processor(
|
| 326 |
+
batch["rgb_cond"], self.cfg.cond_image_size
|
| 327 |
+
)
|
| 328 |
+
batch["mask_cond"] = self.image_processor(
|
| 329 |
+
batch["mask_cond"], self.cfg.cond_image_size
|
| 330 |
+
)
|
| 331 |
+
scene_codes, non_postprocessed_codes = self.get_scene_codes(batch)
|
| 332 |
+
|
| 333 |
+
global_dict = {}
|
| 334 |
+
if self.image_estimator is not None:
|
| 335 |
+
global_dict.update(
|
| 336 |
+
self.image_estimator(batch["rgb_cond"] * batch["mask_cond"])
|
| 337 |
+
)
|
| 338 |
+
if self.global_estimator is not None and estimate_illumination:
|
| 339 |
+
global_dict.update(self.global_estimator(non_postprocessed_codes))
|
| 340 |
+
|
| 341 |
+
device = get_device()
|
| 342 |
+
with torch.no_grad():
|
| 343 |
+
with (
|
| 344 |
+
torch.autocast(device_type=device, enabled=False)
|
| 345 |
+
if "cuda" in device
|
| 346 |
+
else nullcontext()
|
| 347 |
+
):
|
| 348 |
+
meshes = self.triplane_to_meshes(scene_codes)
|
| 349 |
+
|
| 350 |
+
rets = []
|
| 351 |
+
for i, mesh in enumerate(meshes):
|
| 352 |
+
# Check for empty mesh
|
| 353 |
+
if mesh.v_pos.shape[0] == 0:
|
| 354 |
+
rets.append(trimesh.Trimesh())
|
| 355 |
+
continue
|
| 356 |
+
|
| 357 |
+
if remesh == "triangle":
|
| 358 |
+
mesh = mesh.triangle_remesh(triangle_vertex_count=vertex_count)
|
| 359 |
+
elif remesh == "quad":
|
| 360 |
+
mesh = mesh.quad_remesh(quad_vertex_count=vertex_count)
|
| 361 |
+
else:
|
| 362 |
+
if vertex_count > 0:
|
| 363 |
+
print(
|
| 364 |
+
"Warning: vertex_count is ignored when remesh is none"
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
print("After Remesh", mesh.v_pos.shape[0], mesh.t_pos_idx.shape[0])
|
| 368 |
+
mesh.unwrap_uv()
|
| 369 |
+
|
| 370 |
+
# Build textures
|
| 371 |
+
rast = self.baker.rasterize(
|
| 372 |
+
mesh.v_tex, mesh.t_pos_idx, bake_resolution
|
| 373 |
+
)
|
| 374 |
+
bake_mask = self.baker.get_mask(rast)
|
| 375 |
+
|
| 376 |
+
pos_bake = self.baker.interpolate(
|
| 377 |
+
mesh.v_pos,
|
| 378 |
+
rast,
|
| 379 |
+
mesh.t_pos_idx,
|
| 380 |
+
)
|
| 381 |
+
gb_pos = pos_bake[bake_mask]
|
| 382 |
+
|
| 383 |
+
tri_query = self.query_triplane(gb_pos, scene_codes[i])[0]
|
| 384 |
+
decoded = self.decoder(
|
| 385 |
+
tri_query, exclude=["density", "vertex_offset"]
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
nrm = self.baker.interpolate(
|
| 389 |
+
mesh.v_nrm,
|
| 390 |
+
rast,
|
| 391 |
+
mesh.t_pos_idx,
|
| 392 |
+
)
|
| 393 |
+
gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
|
| 394 |
+
decoded["normal"] = gb_nrm
|
| 395 |
+
|
| 396 |
+
# Check if any keys in global_dict start with decoder_
|
| 397 |
+
for k, v in global_dict.items():
|
| 398 |
+
if k.startswith("decoder_"):
|
| 399 |
+
decoded[k.replace("decoder_", "")] = v[i]
|
| 400 |
+
|
| 401 |
+
mat_out = {
|
| 402 |
+
"albedo": decoded["features"],
|
| 403 |
+
"roughness": decoded["roughness"],
|
| 404 |
+
"metallic": decoded["metallic"],
|
| 405 |
+
"normal": normalize(decoded["perturb_normal"]),
|
| 406 |
+
"bump": None,
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
for k, v in mat_out.items():
|
| 410 |
+
if v is None:
|
| 411 |
+
continue
|
| 412 |
+
if v.shape[0] == 1:
|
| 413 |
+
# Skip and directly add a single value
|
| 414 |
+
mat_out[k] = v[0]
|
| 415 |
+
else:
|
| 416 |
+
f = torch.zeros(
|
| 417 |
+
bake_resolution,
|
| 418 |
+
bake_resolution,
|
| 419 |
+
v.shape[-1],
|
| 420 |
+
dtype=v.dtype,
|
| 421 |
+
device=v.device,
|
| 422 |
+
)
|
| 423 |
+
if v.shape == f.shape:
|
| 424 |
+
continue
|
| 425 |
+
if k == "normal":
|
| 426 |
+
# Use un-normalized tangents here so that larger/smaller tris
|
| 427 |
+
# don't affect the tangents that much
|
| 428 |
+
tng = self.baker.interpolate(
|
| 429 |
+
mesh.v_tng,
|
| 430 |
+
rast,
|
| 431 |
+
mesh.t_pos_idx,
|
| 432 |
+
)
|
| 433 |
+
gb_tng = tng[bake_mask]
|
| 434 |
+
gb_tng = F.normalize(gb_tng, dim=-1)
|
| 435 |
+
gb_btng = F.normalize(
|
| 436 |
+
torch.cross(gb_nrm, gb_tng, dim=-1), dim=-1
|
| 437 |
+
)
|
| 438 |
+
normal = F.normalize(mat_out["normal"], dim=-1)
|
| 439 |
+
|
| 440 |
+
# Create tangent space matrix and transform normal
|
| 441 |
+
tangent_matrix = torch.stack(
|
| 442 |
+
[gb_tng, gb_btng, gb_nrm], dim=-1
|
| 443 |
+
)
|
| 444 |
+
normal_tangent = torch.bmm(
|
| 445 |
+
tangent_matrix.transpose(1, 2), normal.unsqueeze(-1)
|
| 446 |
+
).squeeze(-1)
|
| 447 |
+
|
| 448 |
+
# Convert from [-1,1] to [0,1] range for storage
|
| 449 |
+
normal_tangent = (normal_tangent * 0.5 + 0.5).clamp(
|
| 450 |
+
0, 1
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
f[bake_mask] = normal_tangent.view(-1, 3)
|
| 454 |
+
mat_out["bump"] = f
|
| 455 |
+
else:
|
| 456 |
+
f[bake_mask] = v.view(-1, v.shape[-1])
|
| 457 |
+
mat_out[k] = f
|
| 458 |
+
|
| 459 |
+
def uv_padding(arr):
|
| 460 |
+
if arr.ndim == 1:
|
| 461 |
+
return arr
|
| 462 |
+
return (
|
| 463 |
+
dilate_fill(
|
| 464 |
+
arr.permute(2, 0, 1)[None, ...].contiguous(),
|
| 465 |
+
bake_mask.unsqueeze(0).unsqueeze(0),
|
| 466 |
+
iterations=bake_resolution // 150,
|
| 467 |
+
)
|
| 468 |
+
.squeeze(0)
|
| 469 |
+
.permute(1, 2, 0)
|
| 470 |
+
.contiguous()
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
verts_np = convert_data(mesh.v_pos)
|
| 474 |
+
faces = convert_data(mesh.t_pos_idx)
|
| 475 |
+
uvs = convert_data(mesh.v_tex)
|
| 476 |
+
|
| 477 |
+
basecolor_tex = Image.fromarray(
|
| 478 |
+
float32_to_uint8_np(convert_data(uv_padding(mat_out["albedo"])))
|
| 479 |
+
).convert("RGB")
|
| 480 |
+
basecolor_tex.format = "JPEG"
|
| 481 |
+
|
| 482 |
+
metallic = mat_out["metallic"].squeeze().cpu().item()
|
| 483 |
+
roughness = mat_out["roughness"].squeeze().cpu().item()
|
| 484 |
+
|
| 485 |
+
if "bump" in mat_out and mat_out["bump"] is not None:
|
| 486 |
+
bump_np = convert_data(uv_padding(mat_out["bump"]))
|
| 487 |
+
bump_up = np.ones_like(bump_np)
|
| 488 |
+
bump_up[..., :2] = 0.5
|
| 489 |
+
bump_up[..., 2:] = 1
|
| 490 |
+
bump_tex = Image.fromarray(
|
| 491 |
+
float32_to_uint8_np(
|
| 492 |
+
bump_np,
|
| 493 |
+
dither=True,
|
| 494 |
+
# Do not dither if something is perfectly flat
|
| 495 |
+
dither_mask=np.all(
|
| 496 |
+
bump_np == bump_up, axis=-1, keepdims=True
|
| 497 |
+
).astype(np.float32),
|
| 498 |
+
)
|
| 499 |
+
).convert("RGB")
|
| 500 |
+
bump_tex.format = (
|
| 501 |
+
"JPEG" # PNG would be better but the assets are larger
|
| 502 |
+
)
|
| 503 |
+
else:
|
| 504 |
+
bump_tex = None
|
| 505 |
+
|
| 506 |
+
material = trimesh.visual.material.PBRMaterial(
|
| 507 |
+
baseColorTexture=basecolor_tex,
|
| 508 |
+
roughnessFactor=roughness,
|
| 509 |
+
metallicFactor=metallic,
|
| 510 |
+
normalTexture=bump_tex,
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
tmesh = trimesh.Trimesh(
|
| 514 |
+
vertices=verts_np,
|
| 515 |
+
faces=faces,
|
| 516 |
+
visual=trimesh.visual.texture.TextureVisuals(
|
| 517 |
+
uv=uvs, material=material
|
| 518 |
+
),
|
| 519 |
+
)
|
| 520 |
+
rot = trimesh.transformations.rotation_matrix(
|
| 521 |
+
np.radians(-90), [1, 0, 0]
|
| 522 |
+
)
|
| 523 |
+
tmesh.apply_transform(rot)
|
| 524 |
+
tmesh.apply_transform(
|
| 525 |
+
trimesh.transformations.rotation_matrix(
|
| 526 |
+
np.radians(90), [0, 1, 0]
|
| 527 |
+
)
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
tmesh.invert()
|
| 531 |
+
|
| 532 |
+
rets.append(tmesh)
|
| 533 |
+
|
| 534 |
+
return rets, global_dict
|
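`run_image` wraps the whole pipeline: image preparation, scene-code prediction, isosurface extraction, optional remeshing, UV unwrapping and texture baking, returning a textured `trimesh.Trimesh` plus a dict of global estimates. A hedged end-to-end sketch; the input path and bake resolution are placeholders, and the image must already be RGBA with a usable alpha matte.

```python
import torch
from PIL import Image

# `model` is an SF3D instance loaded via from_pretrained (see the earlier sketch).
image = Image.open("input_rgba.png").convert("RGBA")

with torch.no_grad():
    mesh, global_dict = model.run_image(
        image,
        bake_resolution=1024,
        remesh="none",  # or "triangle" / "quad" if the corresponding remeshers are installed
        estimate_illumination=False,
    )

mesh.export("output.glb")
print(global_dict.keys())
```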
sf3d/utils.py
ADDED
|
@@ -0,0 +1,105 @@
import os
from typing import Any, Union

import numpy as np
import rembg
import torch
import torchvision.transforms.functional as torchvision_F
from PIL import Image

import sf3d.models.utils as sf3d_utils


def get_device():
    if os.environ.get("SF3D_USE_CPU", "0") == "1":
        return "cpu"

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    return device


def create_intrinsic_from_fov_deg(fov_deg: float, cond_height: int, cond_width: int):
    intrinsic = sf3d_utils.get_intrinsic_from_fov(
        np.deg2rad(fov_deg),
        H=cond_height,
        W=cond_width,
    )
    intrinsic_normed_cond = intrinsic.clone()
    intrinsic_normed_cond[..., 0, 2] /= cond_width
    intrinsic_normed_cond[..., 1, 2] /= cond_height
    intrinsic_normed_cond[..., 0, 0] /= cond_width
    intrinsic_normed_cond[..., 1, 1] /= cond_height

    return intrinsic, intrinsic_normed_cond


def default_cond_c2w(distance: float):
    c2w_cond = torch.as_tensor(
        [
            [0, 0, 1, distance],
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 1],
        ]
    ).float()
    return c2w_cond


def remove_background(
    image: Image.Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> Image.Image:
    do_remove = True
    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
        # The image already has a non-trivial alpha channel
        do_remove = False
    do_remove = do_remove or force
    if do_remove:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image


def get_1d_bounds(arr):
    nz = np.flatnonzero(arr)
    return nz[0], nz[-1]


def get_bbox_from_mask(mask, thr=0.5):
    masks_for_box = (mask > thr).astype(np.float32)
    assert masks_for_box.sum() > 0, "Empty mask!"
    x0, x1 = get_1d_bounds(masks_for_box.sum(axis=-2))
    y0, y1 = get_1d_bounds(masks_for_box.sum(axis=-1))
    return x0, y0, x1, y1


def resize_foreground(
    image: Union[Image.Image, np.ndarray],
    ratio: float,
    out_size=None,
) -> Image.Image:
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image, mode="RGBA")
    assert image.mode == "RGBA"
    # Get bounding box
    mask_np = np.array(image)[:, :, -1]
    x1, y1, x2, y2 = get_bbox_from_mask(mask_np, thr=0.5)
    h, w = y2 - y1, x2 - x1
    yc, xc = (y1 + y2) / 2, (x1 + x2) / 2
    scale = max(h, w) / ratio

    new_image = torchvision_F.crop(
        image,
        top=int(yc - scale / 2),
        left=int(xc - scale / 2),
        height=int(scale),
        width=int(scale),
    )
    if out_size is not None:
        new_image = new_image.resize(out_size)

    return new_image
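The helpers above are typically chained to turn an arbitrary photo into the RGBA, foreground-centred conditioning image that `prepare_image` expects. A hedged preprocessing sketch; the file names, foreground ratio and target size are placeholders.

```python
import rembg
from PIL import Image

from sf3d.utils import remove_background, resize_foreground

rembg_session = rembg.new_session()

img = Image.open("photo.jpg").convert("RGBA")
img = remove_background(img, rembg_session)  # adds an alpha matte if one is missing
img = resize_foreground(img, ratio=0.85)     # centre and scale the object
img = img.resize((512, 512))                 # match the model's cond_image_size
img.save("input_rgba.png")
```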
texture_baker/README.md
ADDED
|
@@ -0,0 +1,26 @@
# Texture baker

Small texture baker which rasterizes barycentric coordinates to a tensor.
It also implements an interpolation module which can then be used to bake vertex attributes into textures.

## Usage

The baker can quickly bake vertex attributes to a texture atlas based on the UV coordinates.
It supports baking on the CPU and GPU.

```python
from texture_baker import TextureBaker

mesh = ...
uv = mesh.uv  # num_vertex, 2
triangle_idx = mesh.faces  # num_faces, 3
vertices = mesh.vertices  # num_vertex, 3

tb = TextureBaker()
# First get the barycentric coordinates
rast = tb.rasterize(
    uv=uv, face_indices=triangle_idx, bake_resolution=1024
)
# Then interpolate vertex attributes
position_bake = tb.interpolate(attr=vertices, rast=rast, face_indices=triangle_idx)
```
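Beyond the snippet above, `get_mask` exposes which texels were actually covered by a UV triangle; the usual pattern is to keep attributes only for covered texels and scatter them into a full-resolution texture. A short sketch continuing from the README example:

```python
import torch

mask = tb.get_mask(rast)                 # (1024, 1024) bool occupancy
covered_positions = position_bake[mask]  # (num_covered, 3)

texture = torch.zeros(1024, 1024, 3)
texture[mask] = covered_positions        # uncovered texels stay at zero
```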
texture_baker/requirements.txt
ADDED
|
@@ -0,0 +1,2 @@
torch
numpy
texture_baker/setup.py
ADDED
|
@@ -0,0 +1,142 @@
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
import platform
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from setuptools import find_packages, setup
|
| 7 |
+
from torch.utils.cpp_extension import (
|
| 8 |
+
CUDA_HOME,
|
| 9 |
+
BuildExtension,
|
| 10 |
+
CppExtension,
|
| 11 |
+
CUDAExtension,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
library_name = "texture_baker"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_extensions():
|
| 18 |
+
debug_mode = os.getenv("DEBUG", "0") == "1"
|
| 19 |
+
use_cuda = os.getenv("USE_CUDA", "1" if torch.cuda.is_available() else "0") == "1"
|
| 20 |
+
use_metal = (
|
| 21 |
+
os.getenv("USE_METAL", "1" if torch.backends.mps.is_available() else "0") == "1"
|
| 22 |
+
)
|
| 23 |
+
use_native_arch = os.getenv("USE_NATIVE_ARCH", "1") == "1"
|
| 24 |
+
if debug_mode:
|
| 25 |
+
print("Compiling in debug mode")
|
| 26 |
+
|
| 27 |
+
use_cuda = use_cuda and CUDA_HOME is not None
|
| 28 |
+
extension = CUDAExtension if use_cuda else CppExtension
|
| 29 |
+
|
| 30 |
+
is_hip_extension = (
|
| 31 |
+
True
|
| 32 |
+
if (
|
| 33 |
+
(os.environ.get("ROCM_HOME") is not None)
|
| 34 |
+
and (torch.version.hip is not None)
|
| 35 |
+
)
|
| 36 |
+
else False
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
extra_link_args = []
|
| 40 |
+
extra_compile_args = {
|
| 41 |
+
"cxx": (
|
| 42 |
+
[
|
| 43 |
+
"-O3" if not debug_mode else "-O0",
|
| 44 |
+
"-fdiagnostics-color=always",
|
| 45 |
+
"-fopenmp",
|
| 46 |
+
]
|
| 47 |
+
+ ["-march=native"]
|
| 48 |
+
if use_native_arch
|
| 49 |
+
else []
|
| 50 |
+
),
|
| 51 |
+
"nvcc": [
|
| 52 |
+
"-O3" if not debug_mode else "-O0",
|
| 53 |
+
],
|
| 54 |
+
}
|
| 55 |
+
if debug_mode:
|
| 56 |
+
extra_compile_args["cxx"].append("-g")
|
| 57 |
+
if platform.system() == "Windows":
|
| 58 |
+
extra_compile_args["cxx"].append("/Z7")
|
| 59 |
+
extra_compile_args["cxx"].append("/Od")
|
| 60 |
+
extra_link_args.extend(["/DEBUG"])
|
| 61 |
+
extra_compile_args["cxx"].append("-UNDEBUG")
|
| 62 |
+
extra_compile_args["nvcc"].append("-UNDEBUG")
|
| 63 |
+
extra_compile_args["nvcc"].append("-g")
|
| 64 |
+
extra_link_args.extend(["-O0", "-g"])
|
| 65 |
+
|
| 66 |
+
define_macros = []
|
| 67 |
+
extensions = []
|
| 68 |
+
libraries = []
|
| 69 |
+
|
| 70 |
+
this_dir = os.path.dirname(os.path.curdir)
|
| 71 |
+
sources = glob.glob(
|
| 72 |
+
os.path.join(this_dir, library_name, "csrc", "**", "*.cpp"), recursive=True
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
if len(sources) == 0:
|
| 76 |
+
print("No source files found for extension, skipping extension compilation")
|
| 77 |
+
return None
|
| 78 |
+
|
| 79 |
+
if use_cuda:
|
| 80 |
+
define_macros += [
|
| 81 |
+
("THRUST_IGNORE_CUB_VERSION_CHECK", None),
|
| 82 |
+
]
|
| 83 |
+
sources += glob.glob(
|
| 84 |
+
os.path.join(this_dir, library_name, "csrc", "**", "*.cu"), recursive=True
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
if not is_hip_extension:
|
| 88 |
+
libraries += ["cudart", "c10_cuda"]
|
| 89 |
+
|
| 90 |
+
if use_metal:
|
| 91 |
+
define_macros += [
|
| 92 |
+
("WITH_MPS", None),
|
| 93 |
+
]
|
| 94 |
+
sources += glob.glob(
|
| 95 |
+
os.path.join(this_dir, library_name, "csrc", "**", "*.mm"), recursive=True
|
| 96 |
+
)
|
| 97 |
+
extra_compile_args.update(
|
| 98 |
+
{"cxx": ["-O3", "-arch", "arm64", "-mmacosx-version-min=10.15"]}
|
| 99 |
+
)
|
| 100 |
+
extra_link_args += ["-arch", "arm64"]
|
| 101 |
+
|
| 102 |
+
extensions.append(
|
| 103 |
+
extension(
|
| 104 |
+
name=f"{library_name}._C",
|
| 105 |
+
sources=sources,
|
| 106 |
+
define_macros=define_macros,
|
| 107 |
+
extra_compile_args=extra_compile_args,
|
| 108 |
+
extra_link_args=extra_link_args,
|
| 109 |
+
libraries=libraries
|
| 110 |
+
+ [
|
| 111 |
+
"c10",
|
| 112 |
+
"torch",
|
| 113 |
+
"torch_cpu",
|
| 114 |
+
"torch_python",
|
| 115 |
+
],
|
| 116 |
+
)
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
for ext in extensions:
|
| 120 |
+
ext.libraries = ["cudart_static" if x == "cudart" else x for x in ext.libraries]
|
| 121 |
+
|
| 122 |
+
print(extensions)
|
| 123 |
+
|
| 124 |
+
return extensions
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
setup(
|
| 128 |
+
name=library_name,
|
| 129 |
+
version="0.0.1",
|
| 130 |
+
packages=find_packages(where="."),
|
| 131 |
+
package_dir={"": "."},
|
| 132 |
+
ext_modules=get_extensions(),
|
| 133 |
+
install_requires=[],
|
| 134 |
+
package_data={
|
| 135 |
+
library_name: [os.path.join("csrc", "*.h"), os.path.join("csrc", "*.metal")],
|
| 136 |
+
},
|
| 137 |
+
description="Small texture baker which rasterizes barycentric coordinates to a tensor.",
|
| 138 |
+
long_description=open("README.md").read(),
|
| 139 |
+
long_description_content_type="text/markdown",
|
| 140 |
+
url="https://github.com/Stability-AI/texture_baker",
|
| 141 |
+
cmdclass={"build_ext": BuildExtension},
|
| 142 |
+
)
|
texture_baker/texture_baker/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
import torch  # noqa: F401

from . import _C  # noqa: F401
from .baker import TextureBaker  # noqa: F401
texture_baker/texture_baker/baker.py
ADDED
|
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from torch import Tensor


class TextureBaker(nn.Module):
    def __init__(self):
        super().__init__()

    def rasterize(
        self,
        uv: Tensor,
        face_indices: Tensor,
        bake_resolution: int,
    ) -> Tensor:
        """
        Rasterize the UV coordinates to a barycentric-coordinate
        & triangle-index texture map

        Args:
            uv (Tensor, num_vertices 2, float): UV coordinates of the mesh
            face_indices (Tensor, num_faces 3, int): Face indices of the mesh
            bake_resolution (int): Resolution of the bake

        Returns:
            Tensor, bake_resolution bake_resolution 4, float: Rasterized map
        """
        return torch.ops.texture_baker_cpp.rasterize(
            uv, face_indices.to(torch.int32), bake_resolution
        )

    def get_mask(self, rast: Tensor) -> Tensor:
        """
        Get the occupancy mask from the rasterized map

        Args:
            rast (Tensor, bake_resolution bake_resolution 4, float): Rasterized map

        Returns:
            Tensor, bake_resolution bake_resolution, bool: Mask
        """
        return rast[..., -1] >= 0

    def interpolate(
        self,
        attr: Tensor,
        rast: Tensor,
        face_indices: Tensor,
    ) -> Tensor:
        """
        Interpolate the attributes using the rasterized map

        Args:
            attr (Tensor, num_vertices 3, float): Attributes of the mesh
            rast (Tensor, bake_resolution bake_resolution 4, float): Rasterized map
            face_indices (Tensor, num_faces 3, int): Face indices of the mesh

        Returns:
            Tensor, bake_resolution bake_resolution 3, float: Interpolated attributes
        """
        return torch.ops.texture_baker_cpp.interpolate(
            attr, face_indices.to(torch.int32), rast
        )

    def forward(
        self,
        attr: Tensor,
        uv: Tensor,
        face_indices: Tensor,
        bake_resolution: int,
    ) -> Tensor:
        """
        Bake the texture

        Args:
            attr (Tensor, num_vertices 3, float): Attributes of the mesh
            uv (Tensor, num_vertices 2, float): UV coordinates of the mesh
            face_indices (Tensor, num_faces 3, int): Face indices of the mesh
            bake_resolution (int): Resolution of the bake

        Returns:
            Tensor, bake_resolution bake_resolution 3, float: Baked texture
        """
        rast = self.rasterize(uv, face_indices, bake_resolution)
        # interpolate() takes no uv argument; forwarding it was a bug
        return self.interpolate(attr, rast, face_indices)
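With the `forward` fix above (the extra `uv` argument is no longer forwarded to `interpolate`), the module can also be used as a single call that rasterizes and interpolates in one step. A short sketch with made-up mesh tensors:

```python
import torch

from texture_baker import TextureBaker

tb = TextureBaker()

uv = torch.rand(100, 2)                 # num_vertices x 2
faces = torch.randint(0, 100, (50, 3))  # num_faces x 3
normals = torch.nn.functional.normalize(torch.randn(100, 3), dim=-1)

# Equivalent to rasterize(...) followed by interpolate(...).
normal_map = tb(normals, uv, faces, bake_resolution=512)
print(normal_map.shape)  # (512, 512, 3)
```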
texture_baker/texture_baker/csrc/baker.cpp
ADDED
|
@@ -0,0 +1,548 @@
| 1 |
+
#include <ATen/ATen.h>
|
| 2 |
+
#include <ATen/Context.h>
|
| 3 |
+
#include <chrono>
|
| 4 |
+
#include <cmath>
|
| 5 |
+
#include <omp.h>
|
| 6 |
+
#include <torch/extension.h>
|
| 7 |
+
#ifndef __ARM_ARCH_ISA_A64
|
| 8 |
+
#include <immintrin.h>
|
| 9 |
+
#endif
|
| 10 |
+
|
| 11 |
+
#include "baker.h"
|
| 12 |
+
|
| 13 |
+
// #define TIMING
|
| 14 |
+
#define BINS 8
|
| 15 |
+
|
| 16 |
+
namespace texture_baker_cpp {
|
| 17 |
+
// Calculate the centroid of a triangle
|
| 18 |
+
tb_float2 triangle_centroid(const tb_float2 &v0, const tb_float2 &v1,
|
| 19 |
+
const tb_float2 &v2) {
|
| 20 |
+
return {(v0.x + v1.x + v2.x) * 0.3333f, (v0.y + v1.y + v2.y) * 0.3333f};
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
float BVH::find_best_split_plane(const BVHNode &node, int &best_axis,
|
| 24 |
+
int &best_pos, AABB ¢roidBounds) {
|
| 25 |
+
float best_cost = std::numeric_limits<float>::max();
|
| 26 |
+
|
| 27 |
+
for (int axis = 0; axis < 2; ++axis) // We use 2 as we have only x and y
|
| 28 |
+
{
|
| 29 |
+
float boundsMin = centroidBounds.min[axis];
|
| 30 |
+
float boundsMax = centroidBounds.max[axis];
|
| 31 |
+
if (boundsMin == boundsMax) {
|
| 32 |
+
continue;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
// Populate the bins
|
| 36 |
+
float scale = BINS / (boundsMax - boundsMin);
|
| 37 |
+
float leftCountArea[BINS - 1], rightCountArea[BINS - 1];
|
| 38 |
+
int leftSum = 0, rightSum = 0;
|
| 39 |
+
|
| 40 |
+
#ifndef __ARM_ARCH_ISA_A64
|
| 41 |
+
#ifndef _MSC_VER
|
| 42 |
+
if (__builtin_cpu_supports("sse"))
|
| 43 |
+
#elif (defined(_M_AMD64) || defined(_M_X64))
|
| 44 |
+
// SSE supported on Windows
|
| 45 |
+
if constexpr (true)
|
| 46 |
+
#endif
|
| 47 |
+
{
|
| 48 |
+
__m128 min4[BINS], max4[BINS];
|
| 49 |
+
unsigned int count[BINS];
|
| 50 |
+
for (unsigned int i = 0; i < BINS; i++)
|
| 51 |
+
min4[i] = _mm_set_ps1(1e30f), max4[i] = _mm_set_ps1(-1e30f),
|
| 52 |
+
count[i] = 0;
|
| 53 |
+
for (int i = node.start; i < node.end; i++) {
|
| 54 |
+
int tri_idx = triangle_indices[i];
|
| 55 |
+
const Triangle &triangle = triangles[tri_idx];
|
| 56 |
+
|
| 57 |
+
int binIdx = std::min(
|
| 58 |
+
BINS - 1, (int)((triangle.centroid[axis] - boundsMin) * scale));
|
| 59 |
+
count[binIdx]++;
|
| 60 |
+
__m128 v0 = _mm_set_ps(triangle.v0.x, triangle.v0.y, 0.0f, 0.0f);
|
| 61 |
+
__m128 v1 = _mm_set_ps(triangle.v1.x, triangle.v1.y, 0.0f, 0.0f);
|
| 62 |
+
__m128 v2 = _mm_set_ps(triangle.v2.x, triangle.v2.y, 0.0f, 0.0f);
|
| 63 |
+
min4[binIdx] = _mm_min_ps(min4[binIdx], v0);
|
| 64 |
+
max4[binIdx] = _mm_max_ps(max4[binIdx], v0);
|
| 65 |
+
min4[binIdx] = _mm_min_ps(min4[binIdx], v1);
|
| 66 |
+
max4[binIdx] = _mm_max_ps(max4[binIdx], v1);
|
| 67 |
+
min4[binIdx] = _mm_min_ps(min4[binIdx], v2);
|
| 68 |
+
max4[binIdx] = _mm_max_ps(max4[binIdx], v2);
|
| 69 |
+
}
|
| 70 |
+
// gather data for the 7 planes between the 8 bins
|
| 71 |
+
__m128 leftMin4 = _mm_set_ps1(1e30f), rightMin4 = leftMin4;
|
| 72 |
+
__m128 leftMax4 = _mm_set_ps1(-1e30f), rightMax4 = leftMax4;
|
| 73 |
+
for (int i = 0; i < BINS - 1; i++) {
|
| 74 |
+
leftSum += count[i];
|
| 75 |
+
rightSum += count[BINS - 1 - i];
|
| 76 |
+
leftMin4 = _mm_min_ps(leftMin4, min4[i]);
|
| 77 |
+
rightMin4 = _mm_min_ps(rightMin4, min4[BINS - 2 - i]);
|
| 78 |
+
leftMax4 = _mm_max_ps(leftMax4, max4[i]);
|
| 79 |
+
rightMax4 = _mm_max_ps(rightMax4, max4[BINS - 2 - i]);
|
| 80 |
+
float le[4], re[4];
|
| 81 |
+
_mm_store_ps(le, _mm_sub_ps(leftMax4, leftMin4));
|
| 82 |
+
_mm_store_ps(re, _mm_sub_ps(rightMax4, rightMin4));
|
| 83 |
+
// SSE order goes from back to front
|
| 84 |
+
leftCountArea[i] = leftSum * (le[2] * le[3]); // 2D area calculation
|
| 85 |
+
rightCountArea[BINS - 2 - i] =
|
| 86 |
+
rightSum * (re[2] * re[3]); // 2D area calculation
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
#else
|
| 90 |
+
if constexpr (false) {
|
| 91 |
+
}
|
| 92 |
+
#endif
|
| 93 |
+
else {
|
| 94 |
+
struct Bin {
|
| 95 |
+
AABB bounds;
|
| 96 |
+
int triCount = 0;
|
| 97 |
+
} bins[BINS];
|
| 98 |
+
|
| 99 |
+
for (int i = node.start; i < node.end; i++) {
|
| 100 |
+
int tri_idx = triangle_indices[i];
|
| 101 |
+
const Triangle &triangle = triangles[tri_idx];
|
| 102 |
+
|
| 103 |
+
int binIdx = std::min(
|
| 104 |
+
BINS - 1, (int)((triangle.centroid[axis] - boundsMin) * scale));
|
| 105 |
+
bins[binIdx].triCount++;
|
| 106 |
+
bins[binIdx].bounds.grow(triangle.v0);
|
| 107 |
+
bins[binIdx].bounds.grow(triangle.v1);
|
| 108 |
+
bins[binIdx].bounds.grow(triangle.v2);
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
// Gather data for the planes between the bins
|
| 112 |
+
AABB leftBox, rightBox;
|
| 113 |
+
|
| 114 |
+
for (int i = 0; i < BINS - 1; i++) {
|
| 115 |
+
leftSum += bins[i].triCount;
|
| 116 |
+
leftBox.grow(bins[i].bounds);
|
| 117 |
+
leftCountArea[i] = leftSum * leftBox.area();
|
| 118 |
+
|
| 119 |
+
rightSum += bins[BINS - 1 - i].triCount;
|
| 120 |
+
rightBox.grow(bins[BINS - 1 - i].bounds);
|
| 121 |
+
rightCountArea[BINS - 2 - i] = rightSum * rightBox.area();
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
// Calculate SAH cost for the planes
|
| 126 |
+
scale = (boundsMax - boundsMin) / BINS;
|
| 127 |
+
for (int i = 0; i < BINS - 1; i++) {
|
| 128 |
+
float planeCost = leftCountArea[i] + rightCountArea[i];
|
| 129 |
+
if (planeCost < best_cost) {
|
| 130 |
+
best_axis = axis;
|
| 131 |
+
best_pos = i + 1;
|
| 132 |
+
best_cost = planeCost;
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
return best_cost;
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
void BVH::update_node_bounds(BVHNode &node, AABB &centroidBounds) {
#ifndef __ARM_ARCH_ISA_A64
#ifndef _MSC_VER
  if (__builtin_cpu_supports("sse"))
#elif (defined(_M_AMD64) || defined(_M_X64))
  // SSE supported on Windows
  if constexpr (true)
#endif
  {
    __m128 min4 = _mm_set_ps1(1e30f), max4 = _mm_set_ps1(-1e30f);
    __m128 cmin4 = _mm_set_ps1(1e30f), cmax4 = _mm_set_ps1(-1e30f);

    for (int i = node.start; i < node.end; i += 2) {
      int tri_idx1 = triangle_indices[i];
      const Triangle &leafTri1 = triangles[tri_idx1];
      // Check if the second actually exists in the node
      __m128 v0, v1, v2, centroid;
      if (i + 1 < node.end) {
        int tri_idx2 = triangle_indices[i + 1];
        const Triangle leafTri2 = triangles[tri_idx2];

        v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri2.v0.x,
                        leafTri2.v0.y);
        v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri2.v1.x,
                        leafTri2.v1.y);
        v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri2.v2.x,
                        leafTri2.v2.y);
        centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
                              leafTri2.centroid.x, leafTri2.centroid.y);
      } else {
        // Otherwise do some duplicated work
        v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri1.v0.x,
                        leafTri1.v0.y);
        v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri1.v1.x,
                        leafTri1.v1.y);
        v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri1.v2.x,
                        leafTri1.v2.y);
        centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
                              leafTri1.centroid.x, leafTri1.centroid.y);
      }

      min4 = _mm_min_ps(min4, v0);
      max4 = _mm_max_ps(max4, v0);
      min4 = _mm_min_ps(min4, v1);
      max4 = _mm_max_ps(max4, v1);
      min4 = _mm_min_ps(min4, v2);
      max4 = _mm_max_ps(max4, v2);
      cmin4 = _mm_min_ps(cmin4, centroid);
      cmax4 = _mm_max_ps(cmax4, centroid);
    }

    float min_values[4], max_values[4], cmin_values[4], cmax_values[4];
    _mm_store_ps(min_values, min4);
    _mm_store_ps(max_values, max4);
    _mm_store_ps(cmin_values, cmin4);
    _mm_store_ps(cmax_values, cmax4);

    node.bbox.min.x = std::min(min_values[3], min_values[1]);
    node.bbox.min.y = std::min(min_values[2], min_values[0]);
    node.bbox.max.x = std::max(max_values[3], max_values[1]);
    node.bbox.max.y = std::max(max_values[2], max_values[0]);

    centroidBounds.min.x = std::min(cmin_values[3], cmin_values[1]);
    centroidBounds.min.y = std::min(cmin_values[2], cmin_values[0]);
    centroidBounds.max.x = std::max(cmax_values[3], cmax_values[1]);
    centroidBounds.max.y = std::max(cmax_values[2], cmax_values[0]);
  }
#else
  if constexpr (false) {
  }
#endif
  {
    node.bbox.invalidate();
    centroidBounds.invalidate();

    // Calculate the bounding box for the node
    for (int i = node.start; i < node.end; ++i) {
      int tri_idx = triangle_indices[i];
      const Triangle &tri = triangles[tri_idx];
      node.bbox.grow(tri.v0);
      node.bbox.grow(tri.v1);
      node.bbox.grow(tri.v2);
      centroidBounds.grow(tri.centroid);
    }
  }
}

void BVH::build(const tb_float2 *vertices, const tb_int3 *indices,
                const int64_t &num_indices) {
#ifdef TIMING
  auto start = std::chrono::high_resolution_clock::now();
#endif
  // Create triangles
  for (size_t i = 0; i < num_indices; ++i) {
    tb_int3 idx = indices[i];
    triangles.push_back(
        {vertices[idx.x], vertices[idx.y], vertices[idx.z], static_cast<int>(i),
         triangle_centroid(vertices[idx.x], vertices[idx.y], vertices[idx.z])});
  }

  // Initialize triangle_indices
  triangle_indices.resize(triangles.size());
  std::iota(triangle_indices.begin(), triangle_indices.end(), 0);

  // Build BVH nodes
  // Reserve extra capacity to fix windows specific crashes
  nodes.reserve(triangles.size() * 2 + 1);
  nodes.push_back({}); // Create the root node
  root = 0;

  // Define a struct for queue entries
  struct QueueEntry {
    int node_idx;
    int start;
    int end;
  };

  // Queue for breadth-first traversal
  std::queue<QueueEntry> node_queue;
  node_queue.push({root, 0, (int)triangles.size()});

  // Process each node in the queue
  while (!node_queue.empty()) {
    QueueEntry current = node_queue.front();
    node_queue.pop();

    int node_idx = current.node_idx;
    int start = current.start;
    int end = current.end;

    BVHNode &node = nodes[node_idx];
    node.start = start;
    node.end = end;

    // Calculate the bounding box for the node
    AABB centroidBounds;
    update_node_bounds(node, centroidBounds);

    // Determine the best split using SAH
    int best_axis, best_pos;

    float splitCost =
        find_best_split_plane(node, best_axis, best_pos, centroidBounds);
    float nosplitCost = node.calculate_node_cost();

    // Stop condition: if the best cost is greater than or equal to the parent's
    // cost
    if (splitCost >= nosplitCost) {
      // Leaf node
      node.left = node.right = -1;
      continue;
    }

    float scale =
        BINS / (centroidBounds.max[best_axis] - centroidBounds.min[best_axis]);
    int i = node.start;
    int j = node.end - 1;

    // Sort the triangle_indices in the range [start, end) based on the best
    // axis
    while (i <= j) {
      // use the exact calculation we used for binning to prevent rare
      // inaccuracies
      int tri_idx = triangle_indices[i];
      tb_float2 tcentr = triangles[tri_idx].centroid;
      int binIdx = std::min(
          BINS - 1,
          (int)((tcentr[best_axis] - centroidBounds.min[best_axis]) * scale));
      if (binIdx < best_pos)
        i++;
      else
        std::swap(triangle_indices[i], triangle_indices[j--]);
    }
    int leftCount = i - node.start;
    if (leftCount == 0 || leftCount == node.num_triangles()) {
      // Leaf node
      node.left = node.right = -1;
      continue;
    }

    int mid = i;

    // Create and set left child
    node.left = nodes.size();
    nodes.push_back({});
    node_queue.push({node.left, start, mid});

    // Create and set right child
    node = nodes[node_idx]; // Update the node - Potentially stale reference
    node.right = nodes.size();
    nodes.push_back({});
    node_queue.push({node.right, mid, end});
  }
#ifdef TIMING
  auto end = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> elapsed = end - start;
  std::cout << "BVH build time: " << elapsed.count() << "s" << std::endl;
#endif
}

// Utility function to clamp a value between a minimum and a maximum
float clamp(float val, float minVal, float maxVal) {
  return std::min(std::max(val, minVal), maxVal);
}

// Function to check if a point (xy) is inside a triangle defined by vertices
// v1, v2, v3
bool barycentric_coordinates(tb_float2 xy, tb_float2 v1, tb_float2 v2,
                             tb_float2 v3, float &u, float &v, float &w) {
  // Vectors from v1 to v2, v3 and xy
  tb_float2 v1v2 = {v2.x - v1.x, v2.y - v1.y};
  tb_float2 v1v3 = {v3.x - v1.x, v3.y - v1.y};
  tb_float2 xyv1 = {xy.x - v1.x, xy.y - v1.y};

  // Dot products of the vectors
  float d00 = v1v2.x * v1v2.x + v1v2.y * v1v2.y;
  float d01 = v1v2.x * v1v3.x + v1v2.y * v1v3.y;
  float d11 = v1v3.x * v1v3.x + v1v3.y * v1v3.y;
  float d20 = xyv1.x * v1v2.x + xyv1.y * v1v2.y;
  float d21 = xyv1.x * v1v3.x + xyv1.y * v1v3.y;

  // Calculate the barycentric coordinates
  float denom = d00 * d11 - d01 * d01;
  v = (d11 * d20 - d01 * d21) / denom;
  w = (d00 * d21 - d01 * d20) / denom;
  u = 1.0f - v - w;

  // Check if the point is inside the triangle
  return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f);
}

bool BVH::intersect(const tb_float2 &point, float &u, float &v, float &w,
                    int &index) const {
  const int max_stack_size = 64;
  int node_stack[max_stack_size];
  int stack_size = 0;

  node_stack[stack_size++] = root;

  while (stack_size > 0) {
    int node_idx = node_stack[--stack_size];
    const BVHNode &node = nodes[node_idx];

    if (node.is_leaf()) {
      for (int i = node.start; i < node.end; ++i) {
        const Triangle &tri = triangles[triangle_indices[i]];
        if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w)) {
          index = tri.index;
          return true;
        }
      }
    } else {
      if (nodes[node.right].bbox.overlaps(point)) {
        if (stack_size < max_stack_size) {
          node_stack[stack_size++] = node.right;
        } else {
          // Handle stack overflow
          throw std::runtime_error("Node stack overflow");
        }
      }
      if (nodes[node.left].bbox.overlaps(point)) {
        if (stack_size < max_stack_size) {
          node_stack[stack_size++] = node.left;
        } else {
          // Handle stack overflow
          throw std::runtime_error("Node stack overflow");
        }
      }
    }
  }

  return false;
}

torch::Tensor rasterize_cpu(torch::Tensor uv, torch::Tensor indices,
                            int64_t bake_resolution) {
  int width = bake_resolution;
  int height = bake_resolution;
  int num_pixels = width * height;
  torch::Tensor rast_result = torch::empty(
      {bake_resolution, bake_resolution, 4},
      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU));

  float *rast_result_ptr = rast_result.contiguous().data_ptr<float>();
  const tb_float2 *vertices = (tb_float2 *)uv.data_ptr<float>();
  const tb_int3 *tris = (tb_int3 *)indices.data_ptr<int>();

  BVH bvh;
  bvh.build(vertices, tris, indices.size(0));

#ifdef TIMING
  auto start = std::chrono::high_resolution_clock::now();
#endif

#pragma omp parallel for
  for (int idx = 0; idx < num_pixels; ++idx) {
    int x = idx / height;
    int y = idx % height;
    int idx_ = idx * 4; // Note: *4 because we're storing float4 per pixel

    tb_float2 pixel_coord = {float(y) / height, float(x) / width};
    pixel_coord.x = clamp(pixel_coord.x, 0.0f, 1.0f);
    pixel_coord.y = 1.0f - clamp(pixel_coord.y, 0.0f, 1.0f);

    float u, v, w;
    int triangle_idx;
    if (bvh.intersect(pixel_coord, u, v, w, triangle_idx)) {
      rast_result_ptr[idx_ + 0] = u;
      rast_result_ptr[idx_ + 1] = v;
      rast_result_ptr[idx_ + 2] = w;
      rast_result_ptr[idx_ + 3] = static_cast<float>(triangle_idx);
    } else {
      rast_result_ptr[idx_ + 0] = 0.0f;
      rast_result_ptr[idx_ + 1] = 0.0f;
      rast_result_ptr[idx_ + 2] = 0.0f;
      rast_result_ptr[idx_ + 3] = -1.0f;
    }
  }

#ifdef TIMING
  auto end = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> elapsed = end - start;
  std::cout << "Rasterization time: " << elapsed.count() << "s" << std::endl;
#endif
  return rast_result;
}

torch::Tensor interpolate_cpu(torch::Tensor attr, torch::Tensor indices,
                              torch::Tensor rast) {
#ifdef TIMING
  auto start = std::chrono::high_resolution_clock::now();
#endif
  int height = rast.size(0);
  int width = rast.size(1);
  torch::Tensor pos_bake = torch::empty(
      {height, width, 3},
      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU));

  const float *attr_ptr = attr.contiguous().data_ptr<float>();
  const int *indices_ptr = indices.contiguous().data_ptr<int>();
  const float *rast_ptr = rast.contiguous().data_ptr<float>();
  float *output_ptr = pos_bake.contiguous().data_ptr<float>();

  int num_pixels = width * height;

#pragma omp parallel for
  for (int idx = 0; idx < num_pixels; ++idx) {
    int idx_ = idx * 4; // Index into the float4 array (4 floats per pixel)
    tb_float3 barycentric = {
        rast_ptr[idx_ + 0],
        rast_ptr[idx_ + 1],
        rast_ptr[idx_ + 2],
    };
    int triangle_idx = static_cast<int>(rast_ptr[idx_ + 3]);

    if (triangle_idx < 0) {
      output_ptr[idx * 3 + 0] = 0.0f;
      output_ptr[idx * 3 + 1] = 0.0f;
      output_ptr[idx * 3 + 2] = 0.0f;
      continue;
    }

    tb_int3 triangle = {indices_ptr[3 * triangle_idx + 0],
                        indices_ptr[3 * triangle_idx + 1],
                        indices_ptr[3 * triangle_idx + 2]};
    tb_float3 v1 = {attr_ptr[3 * triangle.x + 0], attr_ptr[3 * triangle.x + 1],
                    attr_ptr[3 * triangle.x + 2]};
    tb_float3 v2 = {attr_ptr[3 * triangle.y + 0], attr_ptr[3 * triangle.y + 1],
                    attr_ptr[3 * triangle.y + 2]};
    tb_float3 v3 = {attr_ptr[3 * triangle.z + 0], attr_ptr[3 * triangle.z + 1],
                    attr_ptr[3 * triangle.z + 2]};

    tb_float3 interpolated;
    interpolated.x =
        v1.x * barycentric.x + v2.x * barycentric.y + v3.x * barycentric.z;
    interpolated.y =
        v1.y * barycentric.x + v2.y * barycentric.y + v3.y * barycentric.z;
    interpolated.z =
        v1.z * barycentric.x + v2.z * barycentric.y + v3.z * barycentric.z;

    output_ptr[idx * 3 + 0] = interpolated.x;
    output_ptr[idx * 3 + 1] = interpolated.y;
    output_ptr[idx * 3 + 2] = interpolated.z;
  }

#ifdef TIMING
  auto end = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> elapsed = end - start;
  std::cout << "Interpolation time: " << elapsed.count() << "s" << std::endl;
#endif
  return pos_bake;
}

// Registers _C as a Python extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}

// Defines the operators
TORCH_LIBRARY(texture_baker_cpp, m) {
  m.def("rasterize(Tensor uv, Tensor indices, int bake_resolution) -> Tensor");
  m.def("interpolate(Tensor attr, Tensor indices, Tensor rast) -> Tensor");
}

// Registers CPP implementations
TORCH_LIBRARY_IMPL(texture_baker_cpp, CPU, m) {
  m.impl("rasterize", &rasterize_cpu);
  m.impl("interpolate", &interpolate_cpu);
}

} // namespace texture_baker_cpp
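Both operators defined above are reachable through the PyTorch dispatcher under the texture_baker_cpp namespace once the extension is loaded. The sketch below is a minimal, hypothetical usage example, not part of the repository: it assumes that importing the texture_baker package loads the compiled library, and it infers the tensor layouts from the pointer casts above (float32 (V, 2) UVs, int32 (F, 3) triangle indices, float32 (V, 3) per-vertex attributes).

import torch
import texture_baker  # assumption: importing the package loads the compiled ops

uv = torch.rand(100, 2, dtype=torch.float32)                 # (V, 2) UV coordinates in [0, 1]
indices = torch.randint(0, 100, (50, 3), dtype=torch.int32)  # (F, 3) triangle vertex indices
attr = torch.rand(100, 3, dtype=torch.float32)               # (V, 3) attribute to bake, e.g. positions

# (res, res, 4): barycentric u, v, w plus the hit triangle index (-1 where no triangle covers the texel)
rast = torch.ops.texture_baker_cpp.rasterize(uv, indices, 512)

# (res, res, 3): attr interpolated at every covered texel
baked = torch.ops.texture_baker_cpp.interpolate(attr, indices, rast)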
texture_baker/texture_baker/csrc/baker.h
ADDED
@@ -0,0 +1,203 @@
#pragma once

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__METAL__)
#define CUDA_ENABLED
#ifndef __METAL__
#define CUDA_HOST_DEVICE __host__ __device__
#define CUDA_DEVICE __device__
#define METAL_CONSTANT_MEM
#define METAL_THREAD_MEM
#else
#define tb_float2 float2
#define CUDA_HOST_DEVICE
#define CUDA_DEVICE
#define METAL_CONSTANT_MEM constant
#define METAL_THREAD_MEM thread
#endif
#else
#define CUDA_HOST_DEVICE
#define CUDA_DEVICE
#define METAL_CONSTANT_MEM
#define METAL_THREAD_MEM
#include <cfloat>
#include <limits>
#include <vector>
#endif

namespace texture_baker_cpp {
// Structure to represent a 2D point or vector
#ifndef __METAL__
union alignas(8) tb_float2 {
  struct {
    float x, y;
  };

  float data[2];

  float &operator[](size_t idx) {
    if (idx > 1)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx > 1)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  bool operator==(const tb_float2 &rhs) const {
    return x == rhs.x && y == rhs.y;
  }
};

union alignas(4) tb_float3 {
  struct {
    float x, y, z;
  };

  float data[3];

  float &operator[](size_t idx) {
    if (idx > 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx > 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }
};

union alignas(16) tb_float4 {
  struct {
    float x, y, z, w;
  };

  float data[4];

  float &operator[](size_t idx) {
    if (idx > 3)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx > 3)
      throw std::runtime_error("bad index");
    return data[idx];
  }
};
#endif

union alignas(4) tb_int3 {
  struct {
    int x, y, z;
  };

  int data[3];
#ifndef __METAL__
  int &operator[](size_t idx) {
    if (idx > 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }
#endif
};

// BVH structure to accelerate point-triangle intersection
struct alignas(16) AABB {
  // Init bounding boxes with max/min
  tb_float2 min = {FLT_MAX, FLT_MAX};
  tb_float2 max = {FLT_MIN, FLT_MIN};

#ifndef CUDA_ENABLED
  // grow the AABB to include a point
  void grow(const tb_float2 &p) {
    min.x = std::min(min.x, p.x);
    min.y = std::min(min.y, p.y);
    max.x = std::max(max.x, p.x);
    max.y = std::max(max.y, p.y);
  }

  void grow(const AABB &b) {
    if (b.min.x != FLT_MAX) {
      grow(b.min);
      grow(b.max);
    }
  }
#endif

  // Check if two AABBs overlap
  bool overlaps(const METAL_THREAD_MEM AABB &other) const {
    return min.x <= other.max.x && max.x >= other.min.x &&
           min.y <= other.max.y && max.y >= other.min.y;
  }

  bool overlaps(const METAL_THREAD_MEM tb_float2 &point) const {
    return point.x >= min.x && point.x <= max.x && point.y >= min.y &&
           point.y <= max.y;
  }

#if defined(__NVCC__) || defined(__HIPCC__)
  CUDA_DEVICE bool overlaps(const float2 &point) const {
    return point.x >= min.x && point.x <= max.x && point.y >= min.y &&
           point.y <= max.y;
  }
#endif

  // Initialize AABB to an invalid state
  void invalidate() {
    min = {FLT_MAX, FLT_MAX};
    max = {FLT_MIN, FLT_MIN};
  }

  // Calculate the area of the AABB
  float area() const {
    tb_float2 extent = {max.x - min.x, max.y - min.y};
    return extent.x * extent.y;
  }
};

struct BVHNode {
  AABB bbox;
  int start, end;
  int left, right;

  int num_triangles() const { return end - start; }

  CUDA_HOST_DEVICE bool is_leaf() const { return left == -1 && right == -1; }

  float calculate_node_cost() {
    float area = bbox.area();
    return num_triangles() * area;
  }
};

struct Triangle {
  tb_float2 v0, v1, v2;
  int index;
  tb_float2 centroid;
};

#ifndef __METAL__
struct BVH {
  std::vector<BVHNode> nodes;
  std::vector<Triangle> triangles;
  std::vector<int> triangle_indices;
  int root;

  void build(const tb_float2 *vertices, const tb_int3 *indices,
             const int64_t &num_indices);
  bool intersect(const tb_float2 &point, float &u, float &v, float &w,
                 int &index) const;

  void update_node_bounds(BVHNode &node, AABB &centroidBounds);
  float find_best_split_plane(const BVHNode &node, int &best_axis,
                              int &best_pos, AABB &centroidBounds);
};
#endif

} // namespace texture_baker_cpp
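The header only declares the data types and the BVH interface; the actual inside-triangle test that the CPU, CUDA, and Metal bakers all share solves a small 2x2 system for the barycentric coordinates and accepts a point when v >= 0, w >= 0 and v + w <= 1. The following NumPy sketch is for illustration only (it is not part of the package) and mirrors that computation:

import numpy as np

def barycentric_coordinates(xy, v1, v2, v3):
    # Edge vectors and the vector from v1 to the query point
    v1v2, v1v3, xyv1 = v2 - v1, v3 - v1, xy - v1
    d00, d01, d11 = v1v2 @ v1v2, v1v2 @ v1v3, v1v3 @ v1v3
    d20, d21 = xyv1 @ v1v2, xyv1 @ v1v3
    denom = d00 * d11 - d01 * d01
    v = (d11 * d20 - d01 * d21) / denom
    w = (d00 * d21 - d01 * d20) / denom
    u = 1.0 - v - w
    return (v >= 0.0 and w >= 0.0 and v + w <= 1.0), u, v, w

inside, u, v, w = barycentric_coordinates(
    np.array([0.25, 0.25]),
    np.array([0.0, 0.0]), np.array([1.0, 0.0]), np.array([0.0, 1.0]),
)
# inside is True here, and u + v + w sums to 1 for any query point.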
texture_baker/texture_baker/csrc/baker_kernel.cu
ADDED
@@ -0,0 +1,306 @@
#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "baker.h"

// #define TIMING

#define STRINGIFY(x) #x
#define STR(x) STRINGIFY(x)
#define FILE_LINE __FILE__ ":" STR(__LINE__)
#define CUDA_CHECK_THROW(x) \
  do { \
    cudaError_t _result = x; \
    if (_result != cudaSuccess) \
      throw std::runtime_error(std::string(FILE_LINE " check failed " #x " failed: ") + cudaGetErrorString(_result)); \
  } while(0)

#if defined(__HIPCC__)
#define cudaMallocAsync hipMallocAsync
#define cudaFreeAsync hipFreeAsync
#endif

namespace texture_baker_cpp
{

__device__ float3 operator+(const float3 &a, const float3 &b)
{
    return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}

// xy: 2D test position
// v1: vertex position 1
// v2: vertex position 2
// v3: vertex position 3
//
__forceinline__ __device__ bool barycentric_coordinates(const float2 &xy, const tb_float2 &v1, const tb_float2 &v2, const tb_float2 &v3, float &u, float &v, float &w)
{
    // Return true if the point (xy) is inside the triangle defined by the vertices v1, v2, v3.
    // If the point is inside the triangle, the barycentric coordinates are stored in u, v, and w.
    float2 v1v2 = make_float2(v2.x - v1.x, v2.y - v1.y);
    float2 v1v3 = make_float2(v3.x - v1.x, v3.y - v1.y);
    float2 xyv1 = make_float2(xy.x - v1.x, xy.y - v1.y);

    float d00 = v1v2.x * v1v2.x + v1v2.y * v1v2.y;
    float d01 = v1v2.x * v1v3.x + v1v2.y * v1v3.y;
    float d11 = v1v3.x * v1v3.x + v1v3.y * v1v3.y;
    float d20 = xyv1.x * v1v2.x + xyv1.y * v1v2.y;
    float d21 = xyv1.x * v1v3.x + xyv1.y * v1v3.y;

    float denom = d00 * d11 - d01 * d01;
    v = (d11 * d20 - d01 * d21) / denom;
    w = (d00 * d21 - d01 * d20) / denom;
    u = 1.0f - v - w;

    return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f);
}

__global__ void kernel_interpolate(const float3* __restrict__ attr, const int3* __restrict__ indices, const float4* __restrict__ rast, float3* __restrict__ output, int width, int height)
{
    // Interpolate the attr into output based on the rast result (barycentric coordinates, + triangle idx)
    //int idx = x * width + y;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int x = idx / width;
    int y = idx % width;

    if (x >= width || y >= height)
        return;

    float4 barycentric = rast[idx];
    int triangle_idx = int(barycentric.w);

    if (triangle_idx < 0)
    {
        output[idx] = make_float3(0.0f, 0.0f, 0.0f);
        return;
    }

    float3 v1 = attr[indices[triangle_idx].x];
    float3 v2 = attr[indices[triangle_idx].y];
    float3 v3 = attr[indices[triangle_idx].z];

    output[idx] = make_float3(v1.x * barycentric.x, v1.y * barycentric.x, v1.z * barycentric.x)
        + make_float3(v2.x * barycentric.y, v2.y * barycentric.y, v2.z * barycentric.y)
        + make_float3(v3.x * barycentric.z, v3.y * barycentric.z, v3.z * barycentric.z);
}

__device__ bool bvh_intersect(
    const BVHNode* __restrict__ nodes,
    const Triangle* __restrict__ triangles,
    const int* __restrict__ triangle_indices,
    const int root,
    const float2 &point,
    float &u, float &v, float &w,
    int &index)
{
    constexpr int max_stack_size = 64;
    int node_stack[max_stack_size];
    int stack_size = 0;

    node_stack[stack_size++] = root;

    while (stack_size > 0)
    {
        int node_idx = node_stack[--stack_size];
        const BVHNode &node = nodes[node_idx];

        if (node.is_leaf())
        {
            for (int i = node.start; i < node.end; ++i)
            {
                const Triangle &tri = triangles[triangle_indices[i]];
                if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w))
                {
                    index = tri.index;
                    return true;
                }
            }
        }
        else
        {
            if (nodes[node.right].bbox.overlaps(point))
            {
                if (stack_size < max_stack_size)
                {
                    node_stack[stack_size++] = node.right;
                }
                else
                {
                    // Handle stack overflow
                    // Make sure NDEBUG is not defined (see setup.py)
                    assert(0 && "Node stack overflow");
                }
            }
            if (nodes[node.left].bbox.overlaps(point))
            {
                if (stack_size < max_stack_size)
                {
                    node_stack[stack_size++] = node.left;
                }
                else
                {
                    // Handle stack overflow
                    // Make sure NDEBUG is not defined (see setup.py)
                    assert(0 && "Node stack overflow");
                }
            }
        }
    }

    return false;
}

__global__ void kernel_bake_uv(
    float2* __restrict__ uv,
    int3* __restrict__ indices,
    float4* __restrict__ output,
    const BVHNode* __restrict__ nodes,
    const Triangle* __restrict__ triangles,
    const int* __restrict__ triangle_indices,
    const int root,
    const int width,
    const int height,
    const int num_indices)
{
    //int idx = x * width + y;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int x = idx / width;
    int y = idx % width;

    if (y >= width || x >= height)
        return;

    // We index x,y but the original coords are HW. So swap them
    float2 pixel_coord = make_float2(float(y) / height, float(x) / width);
    pixel_coord.x = fminf(fmaxf(pixel_coord.x, 0.0f), 1.0f);
    pixel_coord.y = 1.0f - fminf(fmaxf(pixel_coord.y, 0.0f), 1.0f);

    float u, v, w;
    int triangle_idx;
    bool hit = bvh_intersect(nodes, triangles, triangle_indices, root, pixel_coord, u, v, w, triangle_idx);

    if (hit)
    {
        output[idx] = make_float4(u, v, w, float(triangle_idx));
        return;
    }

    output[idx] = make_float4(0.0f, 0.0f, 0.0f, -1.0f);
}

torch::Tensor rasterize_gpu(
    torch::Tensor uv,
    torch::Tensor indices,
    int64_t bake_resolution)
{
#ifdef TIMING
    auto start = std::chrono::high_resolution_clock::now();
#endif
    constexpr int block_size = 16 * 16;
    int grid_size = bake_resolution * bake_resolution / block_size;
    dim3 block_dims(block_size, 1, 1);
    dim3 grid_dims(grid_size, 1, 1);

    int num_indices = indices.size(0);

    int width = bake_resolution;
    int height = bake_resolution;

    // Step 1: create an empty tensor to store the output.
    torch::Tensor rast_result = torch::empty({bake_resolution, bake_resolution, 4}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));

    auto vertices_cpu = uv.contiguous().cpu();
    auto indices_cpu = indices.contiguous().cpu();

    const tb_float2 *vertices_cpu_ptr = (tb_float2*)vertices_cpu.contiguous().data_ptr<float>();
    const tb_int3 *tris_cpu_ptr = (tb_int3*)indices_cpu.contiguous().data_ptr<int>();

    BVH bvh;
    bvh.build(vertices_cpu_ptr, tris_cpu_ptr, indices.size(0));

    BVHNode *nodes_gpu = nullptr;
    Triangle *triangles_gpu = nullptr;
    int *triangle_indices_gpu = nullptr;
    const int bvh_root = bvh.root;
    cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();

    CUDA_CHECK_THROW(cudaMallocAsync(&nodes_gpu, sizeof(BVHNode) * bvh.nodes.size(), cuda_stream));
    CUDA_CHECK_THROW(cudaMallocAsync(&triangles_gpu, sizeof(Triangle) * bvh.triangles.size(), cuda_stream));
    CUDA_CHECK_THROW(cudaMallocAsync(&triangle_indices_gpu, sizeof(int) * bvh.triangle_indices.size(), cuda_stream));

    CUDA_CHECK_THROW(cudaMemcpyAsync(nodes_gpu, bvh.nodes.data(), sizeof(BVHNode) * bvh.nodes.size(), cudaMemcpyHostToDevice, cuda_stream));
    CUDA_CHECK_THROW(cudaMemcpyAsync(triangles_gpu, bvh.triangles.data(), sizeof(Triangle) * bvh.triangles.size(), cudaMemcpyHostToDevice, cuda_stream));
    CUDA_CHECK_THROW(cudaMemcpyAsync(triangle_indices_gpu, bvh.triangle_indices.data(), sizeof(int) * bvh.triangle_indices.size(), cudaMemcpyHostToDevice, cuda_stream));

    kernel_bake_uv<<<grid_dims, block_dims, 0, cuda_stream>>>(
        (float2 *)uv.contiguous().data_ptr<float>(),
        (int3 *)indices.contiguous().data_ptr<int>(),
        (float4 *)rast_result.contiguous().data_ptr<float>(),
        nodes_gpu,
        triangles_gpu,
        triangle_indices_gpu,
        bvh_root,
        width,
        height,
        num_indices);

    CUDA_CHECK_THROW(cudaFreeAsync(nodes_gpu, cuda_stream));
    CUDA_CHECK_THROW(cudaFreeAsync(triangles_gpu, cuda_stream));
    CUDA_CHECK_THROW(cudaFreeAsync(triangle_indices_gpu, cuda_stream));

#ifdef TIMING
    CUDA_CHECK_THROW(cudaStreamSynchronize(cuda_stream));
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> elapsed = end - start;
    std::cout << "Rasterization time (CUDA): " << elapsed.count() << "s" << std::endl;
#endif
    return rast_result;
}

torch::Tensor interpolate_gpu(
    torch::Tensor attr,
    torch::Tensor indices,
    torch::Tensor rast)
{
#ifdef TIMING
    auto start = std::chrono::high_resolution_clock::now();
#endif
    constexpr int block_size = 16 * 16;
    int grid_size = rast.size(0) * rast.size(0) / block_size;
    dim3 block_dims(block_size, 1, 1);
    dim3 grid_dims(grid_size, 1, 1);

    // Step 1: create an empty tensor to store the output.
    torch::Tensor pos_bake = torch::empty({rast.size(0), rast.size(1), 3}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));

    int width = rast.size(0);
    int height = rast.size(1);

    cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();

    kernel_interpolate<<<grid_dims, block_dims, 0, cuda_stream>>>(
        (float3 *)attr.contiguous().data_ptr<float>(),
        (int3 *)indices.contiguous().data_ptr<int>(),
        (float4 *)rast.contiguous().data_ptr<float>(),
        (float3 *)pos_bake.contiguous().data_ptr<float>(),
        width,
        height);
#ifdef TIMING
    CUDA_CHECK_THROW(cudaStreamSynchronize(cuda_stream));
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> elapsed = end - start;
    std::cout << "Interpolation time (CUDA): " << elapsed.count() << "s" << std::endl;
#endif
    return pos_bake;
}

// Registers CUDA implementations
TORCH_LIBRARY_IMPL(texture_baker_cpp, CUDA, m)
{
    m.impl("rasterize", &rasterize_gpu);
    m.impl("interpolate", &interpolate_gpu);
}

}
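Because rasterize_gpu and interpolate_gpu are registered for the CUDA dispatch key, the same torch.ops calls shown earlier route to these kernels automatically when the input tensors live on the GPU. One caveat: the launch configuration flattens the bake target into bake_resolution * bake_resolution threads in blocks of 256, so this sketch assumes a resolution that is a multiple of 16 (otherwise the integer division would leave some texels unprocessed). A hedged sketch, reusing the tensors from the CPU example above:

import torch

if torch.cuda.is_available():
    # Same op names as before; the dispatcher picks the CUDA implementations for CUDA inputs.
    rast = torch.ops.texture_baker_cpp.rasterize(uv.cuda(), indices.cuda(), 1024)
    baked = torch.ops.texture_baker_cpp.interpolate(attr.cuda(), indices.cuda(), rast)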
texture_baker/texture_baker/csrc/baker_kernel.metal
ADDED
@@ -0,0 +1,170 @@
#include <metal_stdlib>
using namespace metal;

// This header is inlined manually
//#include "baker.h"

// Use the texture_baker_cpp so it can use the classes from baker.h
using namespace texture_baker_cpp;

// Utility function to compute barycentric coordinates
bool barycentric_coordinates(float2 xy, float2 v1, float2 v2, float2 v3, thread float &u, thread float &v, thread float &w) {
    float2 v1v2 = v2 - v1;
    float2 v1v3 = v3 - v1;
    float2 xyv1 = xy - v1;

    float d00 = dot(v1v2, v1v2);
    float d01 = dot(v1v2, v1v3);
    float d11 = dot(v1v3, v1v3);
    float d20 = dot(xyv1, v1v2);
    float d21 = dot(xyv1, v1v3);

    float denom = d00 * d11 - d01 * d01;
    v = (d11 * d20 - d01 * d21) / denom;
    w = (d00 * d21 - d01 * d20) / denom;
    u = 1.0f - v - w;

    return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f);
}

// Kernel function for interpolation
kernel void kernel_interpolate(constant packed_float3 *attr [[buffer(0)]],
                               constant packed_int3 *indices [[buffer(1)]],
                               constant packed_float4 *rast [[buffer(2)]],
                               device packed_float3 *output [[buffer(3)]],
                               constant int &width [[buffer(4)]],
                               constant int &height [[buffer(5)]],
                               uint3 blockIdx [[threadgroup_position_in_grid]],
                               uint3 threadIdx [[thread_position_in_threadgroup]],
                               uint3 blockDim [[threads_per_threadgroup]])
{
    // Calculate global position using threadgroup and thread positions
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;

    if (x >= width || y >= height) return;

    int idx = y * width + x;
    float4 barycentric = rast[idx];
    int triangle_idx = int(barycentric.w);

    if (triangle_idx < 0) {
        output[idx] = float3(0.0f, 0.0f, 0.0f);
        return;
    }

    float3 v1 = attr[indices[triangle_idx].x];
    float3 v2 = attr[indices[triangle_idx].y];
    float3 v3 = attr[indices[triangle_idx].z];

    output[idx] = v1 * barycentric.x + v2 * barycentric.y + v3 * barycentric.z;
}

bool bvh_intersect(
    constant BVHNode* nodes,
    constant Triangle* triangles,
    constant int* triangle_indices,
    const thread int root,
    const thread float2 &point,
    thread float &u, thread float &v, thread float &w,
    thread int &index)
{
    const int max_stack_size = 64;
    thread int node_stack[max_stack_size];
    int stack_size = 0;

    node_stack[stack_size++] = root;

    while (stack_size > 0)
    {
        int node_idx = node_stack[--stack_size];
        BVHNode node = nodes[node_idx];

        if (node.is_leaf())
        {
            for (int i = node.start; i < node.end; ++i)
            {
                constant Triangle &tri = triangles[triangle_indices[i]];
                if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w))
                {
                    index = tri.index;
                    return true;
                }
            }
        }
        else
        {
            BVHNode test_node = nodes[node.right];
            if (test_node.bbox.overlaps(point))
            {
                if (stack_size < max_stack_size)
                {
                    node_stack[stack_size++] = node.right;
                }
                else
                {
                    // Handle stack overflow
                    // Sadly, metal doesn't support asserts (but you could try enabling metal validation layers)
                    return false;
                }
            }
            test_node = nodes[node.left];
            if (test_node.bbox.overlaps(point))
            {
                if (stack_size < max_stack_size)
                {
                    node_stack[stack_size++] = node.left;
                }
                else
                {
                    // Handle stack overflow
                    return false;
                }
            }
        }
    }

    return false;
}


// Kernel function for baking UV
kernel void kernel_bake_uv(constant packed_float2 *uv [[buffer(0)]],
                           constant packed_int3 *indices [[buffer(1)]],
                           device packed_float4 *output [[buffer(2)]],
                           constant BVHNode *nodes [[buffer(3)]],
                           constant Triangle *triangles [[buffer(4)]],
                           constant int *triangle_indices [[buffer(5)]],
                           constant int &root [[buffer(6)]],
                           constant int &width [[buffer(7)]],
                           constant int &height [[buffer(8)]],
                           constant int &num_indices [[buffer(9)]],
                           uint3 blockIdx [[threadgroup_position_in_grid]],
                           uint3 threadIdx [[thread_position_in_threadgroup]],
                           uint3 blockDim [[threads_per_threadgroup]])
{
    // Calculate global position using threadgroup and thread positions
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;

    if (x >= width || y >= height) return;

    int idx = x * width + y;

    // Swap original coordinates
    float2 pixel_coord = float2(float(y) / float(height), float(x) / float(width));
    pixel_coord = clamp(pixel_coord, 0.0f, 1.0f);
    pixel_coord.y = 1.0f - pixel_coord.y;

    float u, v, w;
    int triangle_idx;
    bool hit = bvh_intersect(nodes, triangles, triangle_indices, root, pixel_coord, u, v, w, triangle_idx);

    if (hit) {
        output[idx] = float4(u, v, w, float(triangle_idx));
        return;
    }

    output[idx] = float4(0.0f, 0.0f, 0.0f, -1.0f);
}
texture_baker/texture_baker/csrc/baker_kernel.mm
ADDED
@@ -0,0 +1,260 @@
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
#include <ATen/ATen.h>
|
| 3 |
+
#include <ATen/Context.h>
|
| 4 |
+
#include "baker.h"
|
| 5 |
+
|
| 6 |
+
#import <Foundation/Foundation.h>
|
| 7 |
+
#import <Metal/Metal.h>
|
| 8 |
+
#include <filesystem>
|
| 9 |
+
|
| 10 |
+
// Helper function to retrieve the `MTLBuffer` from a `torch::Tensor`.
|
| 11 |
+
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
|
| 12 |
+
return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
// Helper function to create a compute pipeline state object (PSO).
|
| 16 |
+
static inline id<MTLComputePipelineState> createComputePipelineState(id<MTLDevice> device, NSString* fullSource, std::string kernel_name) {
|
| 17 |
+
NSError *error = nil;
|
| 18 |
+
|
| 19 |
+
// Load the custom kernel shader.
|
| 20 |
+
MTLCompileOptions *options = [[MTLCompileOptions alloc] init];
|
| 21 |
+
// Add the preprocessor macro "__METAL__"
|
| 22 |
+
options.preprocessorMacros = @{@"__METAL__": @""};
|
| 23 |
+
id<MTLLibrary> customKernelLibrary = [device newLibraryWithSource: fullSource options:options error:&error];
|
| 24 |
+
TORCH_CHECK(customKernelLibrary, "Failed to create custom kernel library, error: ", error.localizedDescription.UTF8String);
|
| 25 |
+
|
| 26 |
+
id<MTLFunction> customKernelFunction = [customKernelLibrary newFunctionWithName:[NSString stringWithUTF8String:kernel_name.c_str()]];
|
| 27 |
+
TORCH_CHECK(customKernelFunction, "Failed to create function state object for ", kernel_name.c_str());
|
| 28 |
+
|
| 29 |
+
id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:customKernelFunction error:&error];
|
| 30 |
+
TORCH_CHECK(pso, error.localizedDescription.UTF8String);
|
| 31 |
+
|
| 32 |
+
return pso;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
std::filesystem::path get_extension_path() {
|
| 36 |
+
// Ensure the GIL is held before calling any Python C API function
|
| 37 |
+
PyGILState_STATE gstate = PyGILState_Ensure();
|
| 38 |
+
|
| 39 |
+
const char* module_name = "texture_baker";
|
| 40 |
+
|
| 41 |
+
// Import the module by name
|
| 42 |
+
PyObject* module = PyImport_ImportModule(module_name);
|
| 43 |
+
if (!module) {
|
| 44 |
+
PyGILState_Release(gstate);
|
| 45 |
+
throw std::runtime_error("Could not import the module: " + std::string(module_name));
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
// Get the filename of the module
|
| 49 |
+
PyObject* filename_obj = PyModule_GetFilenameObject(module);
|
| 50 |
+
if (filename_obj) {
|
| 51 |
+
std::string path = PyUnicode_AsUTF8(filename_obj);
|
| 52 |
+
Py_DECREF(filename_obj);
|
| 53 |
+
PyGILState_Release(gstate);
|
| 54 |
+
|
| 55 |
+
// Get the directory part of the path (removing the __init__.py)
|
| 56 |
+
std::filesystem::path module_path = std::filesystem::path(path).parent_path();
|
| 57 |
+
|
| 58 |
+
// Append the 'csrc' directory to the path
|
| 59 |
+
module_path /= "csrc";
|
| 60 |
+
|
| 61 |
+
return module_path;
|
| 62 |
+
} else {
|
| 63 |
+
PyGILState_Release(gstate);
|
| 64 |
+
throw std::runtime_error("Could not retrieve the module filename.");
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
NSString *get_shader_sources_as_string()
|
| 69 |
+
{
|
| 70 |
+
const std::filesystem::path csrc_path = get_extension_path();
|
| 71 |
+
const std::string shader_path = (csrc_path / "baker_kernel.metal").string();
|
| 72 |
+
const std::string shader_header_path = (csrc_path / "baker.h").string();
|
| 73 |
+
// Load the Metal shader from the specified path
|
| 74 |
+
NSError *error = nil;
|
| 75 |
+
|
| 76 |
+
NSString* shaderHeaderSource = [
|
| 77 |
+
NSString stringWithContentsOfFile:[NSString stringWithUTF8String:shader_header_path.c_str()]
|
| 78 |
+
encoding:NSUTF8StringEncoding
|
| 79 |
+
error:&error];
|
| 80 |
+
if (error) {
|
| 81 |
+
throw std::runtime_error("Failed to load baker.h: " + std::string(error.localizedDescription.UTF8String));
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
NSString* shaderSource = [
|
| 85 |
+
NSString stringWithContentsOfFile:[NSString stringWithUTF8String:shader_path.c_str()]
|
| 86 |
+
encoding:NSUTF8StringEncoding
|
| 87 |
+
error:&error];
|
| 88 |
+
if (error) {
|
| 89 |
+
throw std::runtime_error("Failed to load Metal shader: " + std::string(error.localizedDescription.UTF8String));
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
NSString *fullSource = [shaderHeaderSource stringByAppendingString:shaderSource];
|
| 93 |
+
|
| 94 |
+
return fullSource;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
namespace texture_baker_cpp
|
| 98 |
+
{
|
| 99 |
+
torch::Tensor rasterize_gpu(
|
| 100 |
+
torch::Tensor uv,
|
| 101 |
+
torch::Tensor indices,
|
| 102 |
+
int64_t bake_resolution)
|
| 103 |
+
{
|
| 104 |
+
TORCH_CHECK(uv.device().is_mps(), "uv must be a MPS tensor");
|
| 105 |
+
TORCH_CHECK(uv.is_contiguous(), "uv must be contiguous");
|
| 106 |
+
TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous");
|
| 107 |
+
|
| 108 |
+
TORCH_CHECK(uv.scalar_type() == torch::kFloat32, "Unsupported data type: ", indices.scalar_type());
|
| 109 |
+
TORCH_CHECK(indices.scalar_type() == torch::kInt32, "Unsupported data type: ", indices.scalar_type());
|
| 110 |
+
|
| 111 |
+
torch::Tensor rast_result = torch::empty({bake_resolution, bake_resolution, 4}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kMPS)).contiguous();
|
| 112 |
+
|
| 113 |
+
@autoreleasepool {
|
| 114 |
+
auto vertices_cpu = uv.contiguous().cpu();
|
| 115 |
+
auto indices_cpu = indices.contiguous().cpu();
|
| 116 |
+
|
| 117 |
+
const tb_float2 *vertices_cpu_ptr = (tb_float2*)vertices_cpu.contiguous().data_ptr<float>();
|
| 118 |
+
const tb_int3 *tris_cpu_ptr = (tb_int3*)indices_cpu.contiguous().data_ptr<int>();
|
| 119 |
+
|
| 120 |
+
BVH bvh;
|
| 121 |
+
bvh.build(vertices_cpu_ptr, tris_cpu_ptr, indices.size(0));
|
| 122 |
+
|
| 123 |
+
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
| 124 |
+
|
| 125 |
+
NSString *fullSource = get_shader_sources_as_string();
|
| 126 |
+
|
| 127 |
+
// Create a compute pipeline state object using the helper function
|
| 128 |
+
id<MTLComputePipelineState> bake_uv_PSO = createComputePipelineState(device, fullSource, "kernel_bake_uv");
|
| 129 |
+
|
| 130 |
+
// Get a reference to the command buffer for the MPS stream.
|
| 131 |
+
id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
|
| 132 |
+
TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");
|
| 133 |
+
|
| 134 |
+
// Get a reference to the dispatch queue for the MPS stream, which encodes the synchronization with the CPU.
|
| 135 |
+
dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();
|
| 136 |
+
|
| 137 |
+
dispatch_sync(serialQueue, ^(){
|
| 138 |
+
// Start a compute pass.
|
| 139 |
+
id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
|
| 140 |
+
TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");
|
| 141 |
+
|
| 142 |
+
// Get Metal buffers directly from PyTorch tensors
|
| 143 |
+
auto uv_buf = getMTLBufferStorage(uv.contiguous());
|
| 144 |
+
auto indices_buf = getMTLBufferStorage(indices.contiguous());
|
| 145 |
+
auto rast_result_buf = getMTLBufferStorage(rast_result);
|
| 146 |
+
|
| 147 |
+
const int width = bake_resolution;
|
| 148 |
+
const int height = bake_resolution;
|
| 149 |
+
const int num_indices = indices.size(0);
|
| 150 |
+
const int bvh_root = bvh.root;
|
| 151 |
+
|
| 152 |
+
// Wrap the existing CPU memory in Metal buffers with shared memory
|
| 153 |
+
id<MTLBuffer> nodesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.nodes.data() length:sizeof(BVHNode) * bvh.nodes.size() options:MTLResourceStorageModeShared deallocator:nil];
|
| 154 |
+
id<MTLBuffer> trianglesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.triangles.data() length:sizeof(Triangle) * bvh.triangles.size() options:MTLResourceStorageModeShared deallocator:nil];
|
| 155 |
+
id<MTLBuffer> triangleIndicesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.triangle_indices.data() length:sizeof(int) * bvh.triangle_indices.size() options:MTLResourceStorageModeShared deallocator:nil];
|
| 156 |
+
|
| 157 |
+
[computeEncoder setComputePipelineState:bake_uv_PSO];
|
| 158 |
+
[computeEncoder setBuffer:uv_buf offset:uv.storage_offset() * uv.element_size() atIndex:0];
|
| 159 |
+
[computeEncoder setBuffer:indices_buf offset:indices.storage_offset() * indices.element_size() atIndex:1];
|
| 160 |
+
[computeEncoder setBuffer:rast_result_buf offset:rast_result.storage_offset() * rast_result.element_size() atIndex:2];
|
| 161 |
+
[computeEncoder setBuffer:nodesBuffer offset:0 atIndex:3];
|
| 162 |
+
[computeEncoder setBuffer:trianglesBuffer offset:0 atIndex:4];
|
| 163 |
+
[computeEncoder setBuffer:triangleIndicesBuffer offset:0 atIndex:5];
|
| 164 |
+
[computeEncoder setBytes:&bvh_root length:sizeof(int) atIndex:6];
|
| 165 |
+
[computeEncoder setBytes:&width length:sizeof(int) atIndex:7];
|
| 166 |
+
      [computeEncoder setBytes:&height length:sizeof(int) atIndex:8];
      [computeEncoder setBytes:&num_indices length:sizeof(int) atIndex:9];

      // Calculate a thread group size.
      int block_size = 16;
      MTLSize threadgroupSize = MTLSizeMake(block_size, block_size, 1); // Fixed threadgroup size
      MTLSize numThreadgroups = MTLSizeMake(bake_resolution / block_size, bake_resolution / block_size, 1);

      // Encode the compute command.
      [computeEncoder dispatchThreadgroups:numThreadgroups threadsPerThreadgroup:threadgroupSize];
      [computeEncoder endEncoding];

      // Commit the work.
      torch::mps::commit();
    });
  }

  return rast_result;
}

torch::Tensor interpolate_gpu(
    torch::Tensor attr,
    torch::Tensor indices,
    torch::Tensor rast)
{
  TORCH_CHECK(attr.is_contiguous(), "attr must be contiguous");
  TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous");
  TORCH_CHECK(rast.is_contiguous(), "rast must be contiguous");

  torch::Tensor pos_bake = torch::empty({rast.size(0), rast.size(1), 3}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kMPS)).contiguous();
  std::filesystem::path csrc_path = get_extension_path();

  @autoreleasepool {
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();

    NSString *fullSource = get_shader_sources_as_string();
    // Create a compute pipeline state object using the helper function
    id<MTLComputePipelineState> interpolate_PSO = createComputePipelineState(device, fullSource, "kernel_interpolate");

    // Get a reference to the command buffer for the MPS stream.
    id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
    TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");

    // Get a reference to the dispatch queue for the MPS stream, which encodes the synchronization with the CPU.
    dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();

    dispatch_sync(serialQueue, ^(){
      // Start a compute pass.
      id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
      TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");

      // Get Metal buffers directly from PyTorch tensors
      auto attr_buf = getMTLBufferStorage(attr.contiguous());
      auto indices_buf = getMTLBufferStorage(indices.contiguous());
      auto rast_buf = getMTLBufferStorage(rast.contiguous());
      auto pos_bake_buf = getMTLBufferStorage(pos_bake);

      int width = rast.size(0);
      int height = rast.size(1);

      [computeEncoder setComputePipelineState:interpolate_PSO];
      [computeEncoder setBuffer:attr_buf offset:attr.storage_offset() * attr.element_size() atIndex:0];
      [computeEncoder setBuffer:indices_buf offset:indices.storage_offset() * indices.element_size() atIndex:1];
      [computeEncoder setBuffer:rast_buf offset:rast.storage_offset() * rast.element_size() atIndex:2];
      [computeEncoder setBuffer:pos_bake_buf offset:pos_bake.storage_offset() * pos_bake.element_size() atIndex:3];
      [computeEncoder setBytes:&width length:sizeof(int) atIndex:4];
      [computeEncoder setBytes:&height length:sizeof(int) atIndex:5];

      // Calculate a thread group size.
      int block_size = 16;
      MTLSize threadgroupSize = MTLSizeMake(block_size, block_size, 1); // Fixed threadgroup size
      MTLSize numThreadgroups = MTLSizeMake(rast.size(0) / block_size, rast.size(0) / block_size, 1);

      // Encode the compute command.
      [computeEncoder dispatchThreadgroups:numThreadgroups threadsPerThreadgroup:threadgroupSize];

      [computeEncoder endEncoding];

      // Commit the work.
      torch::mps::commit();
    });
  }

  return pos_bake;
}

// Registers MPS implementations
TORCH_LIBRARY_IMPL(texture_baker_cpp, MPS, m)
{
  m.impl("rasterize", &rasterize_gpu);
  m.impl("interpolate", &interpolate_gpu);
}

}
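The TORCH_LIBRARY_IMPL block above only supplies the MPS implementations; the operator schemas are declared elsewhere in the extension, and registered ops are reached from Python through the torch.ops namespace. A minimal sketch of such a call follows: the interpolate(attr, indices, rast) argument order mirrors the interpolate_gpu signature above, but the tensor shapes and dtypes are illustrative assumptions, not the library's documented contract.

# Sketch, assuming the compiled texture_baker_cpp extension has already been
# loaded (importing the texture_baker Python package normally does this) and
# that an MPS device is available.
import torch

attr = torch.rand(100, 3, device="mps")                                     # per-vertex attributes (assumed shape)
indices = torch.randint(0, 100, (50, 3), device="mps", dtype=torch.int32)   # triangle indices (assumed shape/dtype)
rast = torch.zeros(512, 512, 4, device="mps")                               # rasterizer output (assumed shape)

# Registered ops live under torch.ops.<namespace>.<name>.
pos_bake = torch.ops.texture_baker_cpp.interpolate(attr, indices, rast)
print(pos_bake.shape)  # (512, 512, 3), following the torch::empty call in interpolate_gpu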
uv_unwrapper/README.md
ADDED
File without changes

uv_unwrapper/requirements.txt
ADDED
@@ -0,0 +1,2 @@
torch
numpy
uv_unwrapper/setup.py
ADDED
@@ -0,0 +1,83 @@
import glob
import os

import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
)

library_name = "uv_unwrapper"


def get_extensions():
    debug_mode = os.getenv("DEBUG", "0") == "1"
    if debug_mode:
        print("Compiling in debug mode")

    is_mac = True if torch.backends.mps.is_available() else False
    use_native_arch = not is_mac and os.getenv("USE_NATIVE_ARCH", "1") == "1"
    extension = CppExtension

    extra_link_args = []
    extra_compile_args = {
        "cxx": (
            [
                "-O3" if not debug_mode else "-O0",
                "-fdiagnostics-color=always",
                ("-Xclang " if is_mac else "") + "-fopenmp",
            ]
            + ["-march=native"]
            if use_native_arch
            else [] + ["-mmacosx-version-min=10.15"] if is_mac else []
        ),
    }
    if debug_mode:
        extra_compile_args["cxx"].append("-g")
        extra_compile_args["cxx"].append("-UNDEBUG")
        extra_link_args.extend(["-O0", "-g"])

    define_macros = []
    extensions = []

    this_dir = os.path.dirname(os.path.curdir)
    sources = glob.glob(
        os.path.join(this_dir, library_name, "csrc", "**", "*.cpp"), recursive=True
    )

    if len(sources) == 0:
        print("No source files found for extension, skipping extension compilation")
        return None

    extensions.append(
        extension(
            name=f"{library_name}._C",
            sources=sources,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args,
            libraries=(
                ["c10", "torch", "torch_cpu", "torch_python"] + ["omp"]
                if is_mac
                else []
            ),
        )
    )

    print(extensions)

    return extensions


setup(
    name=library_name,
    version="0.0.1",
    packages=find_packages(),
    ext_modules=get_extensions(),
    install_requires=[],
    description="Box projection based UV unwrapper",
    long_description=open("README.md").read(),
    long_description_content_type="text/markdown",
    cmdclass={"build_ext": BuildExtension},
)
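The setup script compiles every .cpp file under uv_unwrapper/csrc into a single CppExtension named uv_unwrapper._C, with the DEBUG and USE_NATIVE_ARCH environment variables switching debug flags and -march=native on or off. A small sketch to confirm the compiled module exists after building, assuming the package was installed in the usual way (for example with pip install -e uv_unwrapper/):

# Sketch: check that the extension module declared in setup.py
# (f"{library_name}._C", i.e. "uv_unwrapper._C") was actually built.
import importlib.util

spec = importlib.util.find_spec("uv_unwrapper._C")
print("uv_unwrapper._C available:", spec is not None)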
uv_unwrapper/uv_unwrapper/__init__.py
ADDED
@@ -0,0 +1,6 @@
import torch  # noqa: F401

from . import _C  # noqa: F401
from .unwrap import Unwrapper

__all__ = ["Unwrapper"]
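The import order in __init__.py is deliberate: torch is imported first, presumably so its shared libraries are loaded before the compiled _C extension that links against them, and Unwrapper is the only exported symbol. A minimal import sketch (the constructor and call signature of Unwrapper live in unwrap.py further down in this commit, so no arguments are shown here):

# Importing the package loads torch, then the compiled _C extension, and
# exposes the Unwrapper class as the single public entry point.
from uv_unwrapper import Unwrapper

print(Unwrapper)  # box-projection based UV unwrapper defined in unwrap.py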
uv_unwrapper/uv_unwrapper/csrc/bvh.cpp
ADDED
@@ -0,0 +1,381 @@


#include "bvh.h"
#include "common.h"
#include <cstring>
#include <iostream>
#include <queue>
#include <tuple>
#include <utility>

namespace UVUnwrapper {
BVH::BVH(Triangle *tri, int *actual_idx, const size_t &num_indices) {
  // Copty tri to triangle
  triangle = new Triangle[num_indices];
  memcpy(triangle, tri, num_indices * sizeof(Triangle));

  // Copy actual_idx to actualIdx
  actualIdx = new int[num_indices];
  memcpy(actualIdx, actual_idx, num_indices * sizeof(int));

  triIdx = new int[num_indices];
  triCount = num_indices;

  bvhNode = new BVHNode[triCount * 2 + 64];
  nodesUsed = 2;
  memset(bvhNode, 0, triCount * 2 * sizeof(BVHNode));

  // populate triangle index array
  for (int i = 0; i < triCount; i++)
    triIdx[i] = i;

  BVHNode &root = bvhNode[0];

  root.start = 0, root.end = triCount;
  AABB centroidBounds;
  UpdateNodeBounds(0, centroidBounds);

  // subdivide recursively
  Subdivide(0, nodesUsed, centroidBounds);
}

BVH::BVH(const BVH &other)
    : BVH(other.triangle, other.triIdx, other.triCount) {}

BVH::BVH(BVH &&other) noexcept // move constructor
    : triIdx(std::exchange(other.triIdx, nullptr)),
      actualIdx(std::exchange(other.actualIdx, nullptr)),
      triangle(std::exchange(other.triangle, nullptr)),
      bvhNode(std::exchange(other.bvhNode, nullptr)) {}

BVH &BVH::operator=(const BVH &other) // copy assignment
{
  return *this = BVH(other);
}

BVH &BVH::operator=(BVH &&other) noexcept // move assignment
{
  std::swap(triIdx, other.triIdx);
  std::swap(actualIdx, other.actualIdx);
  std::swap(triangle, other.triangle);
  std::swap(bvhNode, other.bvhNode);
  std::swap(triCount, other.triCount);
  std::swap(nodesUsed, other.nodesUsed);
  return *this;
}

BVH::~BVH() {
  if (triIdx)
    delete[] triIdx;
  if (triangle)
    delete[] triangle;
  if (actualIdx)
    delete[] actualIdx;
  if (bvhNode)
    delete[] bvhNode;
}

void BVH::UpdateNodeBounds(unsigned int nodeIdx, AABB &centroidBounds) {
  BVHNode &node = bvhNode[nodeIdx];
#ifndef __ARM_ARCH_ISA_A64
#ifndef _MSC_VER
  if (__builtin_cpu_supports("sse"))
#elif (defined(_M_AMD64) || defined(_M_X64))
  // SSE supported on Windows
  if constexpr (true)
#endif
  {
    __m128 min4 = _mm_set_ps1(FLT_MAX), max4 = _mm_set_ps1(FLT_MIN);
    __m128 cmin4 = _mm_set_ps1(FLT_MAX), cmax4 = _mm_set_ps1(FLT_MIN);
    for (int i = node.start; i < node.end; i += 2) {
      Triangle &leafTri1 = triangle[triIdx[i]];
      __m128 v0, v1, v2, centroid;
      if (i + 1 < node.end) {
        const Triangle leafTri2 = triangle[triIdx[i + 1]];

        v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri2.v0.x,
                        leafTri2.v0.y);
        v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri2.v1.x,
                        leafTri2.v1.y);
        v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri2.v2.x,
                        leafTri2.v2.y);
        centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
                              leafTri2.centroid.x, leafTri2.centroid.y);
      } else {
        // Otherwise do some duplicated work
        v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri1.v0.x,
                        leafTri1.v0.y);
        v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri1.v1.x,
                        leafTri1.v1.y);
        v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri1.v2.x,
                        leafTri1.v2.y);
        centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
                              leafTri1.centroid.x, leafTri1.centroid.y);
      }

      min4 = _mm_min_ps(min4, v0);
      max4 = _mm_max_ps(max4, v0);
      min4 = _mm_min_ps(min4, v1);
      max4 = _mm_max_ps(max4, v1);
      min4 = _mm_min_ps(min4, v2);
      max4 = _mm_max_ps(max4, v2);
      cmin4 = _mm_min_ps(cmin4, centroid);
      cmax4 = _mm_max_ps(cmax4, centroid);
    }
    float min_values[4], max_values[4], cmin_values[4], cmax_values[4];
    _mm_store_ps(min_values, min4);
    _mm_store_ps(max_values, max4);
    _mm_store_ps(cmin_values, cmin4);
    _mm_store_ps(cmax_values, cmax4);

    node.bbox.min.x = std::min(min_values[3], min_values[1]);
    node.bbox.min.y = std::min(min_values[2], min_values[0]);
    node.bbox.max.x = std::max(max_values[3], max_values[1]);
    node.bbox.max.y = std::max(max_values[2], max_values[0]);

    centroidBounds.min.x = std::min(cmin_values[3], cmin_values[1]);
    centroidBounds.min.y = std::min(cmin_values[2], cmin_values[0]);
    centroidBounds.max.x = std::max(cmax_values[3], cmax_values[1]);
    centroidBounds.max.y = std::max(cmax_values[2], cmax_values[0]);
  }
#else
  if constexpr (false) {
  }
#endif
  else {
    node.bbox.invalidate();
    centroidBounds.invalidate();

    // Calculate the bounding box for the node
    for (int i = node.start; i < node.end; ++i) {
      const Triangle &tri = triangle[triIdx[i]];
      node.bbox.grow(tri.v0);
      node.bbox.grow(tri.v1);
      node.bbox.grow(tri.v2);
      centroidBounds.grow(tri.centroid);
    }
  }
}

void BVH::Subdivide(unsigned int root_idx, unsigned int &nodePtr,
                    AABB &rootCentroidBounds) {
  // Create a queue for the nodes to be subdivided
  std::queue<std::tuple<unsigned int, AABB>> nodeQueue;
  nodeQueue.push(std::make_tuple(root_idx, rootCentroidBounds));

  while (!nodeQueue.empty()) {
    // Get the next node to process from the queue
    auto [node_idx, centroidBounds] = nodeQueue.front();
    nodeQueue.pop();
    BVHNode &node = bvhNode[node_idx];

    // Check if left is -1 and right not or vice versa

    int axis, splitPos;
    float cost = FindBestSplitPlane(node, axis, splitPos, centroidBounds);

    if (cost >= node.calculate_node_cost()) {
      node.left = node.right = -1;
      continue; // Move on to the next node in the queue
    }

    int i = node.start;
    int j = node.end - 1;
    float scale = BINS / (centroidBounds.max[axis] - centroidBounds.min[axis]);
    while (i <= j) {
      int binIdx =
          std::min(BINS - 1, (int)((triangle[triIdx[i]].centroid[axis] -
                                    centroidBounds.min[axis]) *
                                   scale));
      if (binIdx < splitPos)
        i++;
      else
        std::swap(triIdx[i], triIdx[j--]);
    }

    int leftCount = i - node.start;
    if (leftCount == 0 || leftCount == (int)node.num_triangles()) {
      node.left = node.right = -1;
      continue; // Move on to the next node in the queue
    }

    int mid = i;

    // Create child nodes
    int leftChildIdx = nodePtr++;
    int rightChildIdx = nodePtr++;
    bvhNode[leftChildIdx].start = node.start;
    bvhNode[leftChildIdx].end = mid;
    bvhNode[rightChildIdx].start = mid;
    bvhNode[rightChildIdx].end = node.end;
    node.left = leftChildIdx;
    node.right = rightChildIdx;

    // Update the bounds for the child nodes and push them onto the queue
    UpdateNodeBounds(leftChildIdx, centroidBounds);
    nodeQueue.push(std::make_tuple(leftChildIdx, centroidBounds));

    UpdateNodeBounds(rightChildIdx, centroidBounds);
    nodeQueue.push(std::make_tuple(rightChildIdx, centroidBounds));
  }
}

float BVH::FindBestSplitPlane(BVHNode &node, int &best_axis, int &best_pos,
                              AABB &centroidBounds) {
  float best_cost = FLT_MAX;

  for (int axis = 0; axis < 2; ++axis) // We use 2 as we have only x and y
  {
    float boundsMin = centroidBounds.min[axis];
    float boundsMax = centroidBounds.max[axis];
    // Or floating point precision
    if ((boundsMin == boundsMax) || (boundsMax - boundsMin < 1e-8f)) {
      continue;
    }

    // populate the bins
    float scale = BINS / (boundsMax - boundsMin);
    float leftCountArea[BINS - 1], rightCountArea[BINS - 1];
    int leftSum = 0, rightSum = 0;
#ifndef __ARM_ARCH_ISA_A64
#ifndef _MSC_VER
    if (__builtin_cpu_supports("sse"))
#elif (defined(_M_AMD64) || defined(_M_X64))
    // SSE supported on Windows
    if constexpr (true)
#endif
    {
      __m128 min4[BINS], max4[BINS];
      unsigned int count[BINS];
      for (unsigned int i = 0; i < BINS; i++)
        min4[i] = _mm_set_ps1(FLT_MAX), max4[i] = _mm_set_ps1(FLT_MIN),
        count[i] = 0;
      for (int i = node.start; i < node.end; i++) {
        Triangle &tri = triangle[triIdx[i]];
        int binIdx =
            std::min(BINS - 1, (int)((tri.centroid[axis] - boundsMin) * scale));
        count[binIdx]++;

        __m128 v0 = _mm_set_ps(tri.v0.x, tri.v0.y, 0.0f, 0.0f);
        __m128 v1 = _mm_set_ps(tri.v1.x, tri.v1.y, 0.0f, 0.0f);
        __m128 v2 = _mm_set_ps(tri.v2.x, tri.v2.y, 0.0f, 0.0f);
        min4[binIdx] = _mm_min_ps(min4[binIdx], v0);
        max4[binIdx] = _mm_max_ps(max4[binIdx], v0);
        min4[binIdx] = _mm_min_ps(min4[binIdx], v1);
        max4[binIdx] = _mm_max_ps(max4[binIdx], v1);
        min4[binIdx] = _mm_min_ps(min4[binIdx], v2);
        max4[binIdx] = _mm_max_ps(max4[binIdx], v2);
      }
      // gather data for the 7 planes between the 8 bins
      __m128 leftMin4 = _mm_set_ps1(FLT_MAX), rightMin4 = leftMin4;
      __m128 leftMax4 = _mm_set_ps1(FLT_MIN), rightMax4 = leftMax4;
      for (int i = 0; i < BINS - 1; i++) {
        leftSum += count[i];
        rightSum += count[BINS - 1 - i];
        leftMin4 = _mm_min_ps(leftMin4, min4[i]);
        rightMin4 = _mm_min_ps(rightMin4, min4[BINS - 2 - i]);
        leftMax4 = _mm_max_ps(leftMax4, max4[i]);
        rightMax4 = _mm_max_ps(rightMax4, max4[BINS - 2 - i]);
        float le[4], re[4];
        _mm_store_ps(le, _mm_sub_ps(leftMax4, leftMin4));
        _mm_store_ps(re, _mm_sub_ps(rightMax4, rightMin4));
        // SSE order goes from back to front
        leftCountArea[i] = leftSum * (le[2] * le[3]); // 2D area calculation
        rightCountArea[BINS - 2 - i] =
            rightSum * (re[2] * re[3]); // 2D area calculation
      }
    }
#else
    if constexpr (false) {
    }
#endif
    else {
      struct Bin {
        AABB bounds;
        int triCount = 0;
      } bin[BINS];
      for (int i = node.start; i < node.end; i++) {
        Triangle &tri = triangle[triIdx[i]];
        int binIdx =
            std::min(BINS - 1, (int)((tri.centroid[axis] - boundsMin) * scale));
        bin[binIdx].triCount++;
        bin[binIdx].bounds.grow(tri.v0);
        bin[binIdx].bounds.grow(tri.v1);
        bin[binIdx].bounds.grow(tri.v2);
      }
      // gather data for the 7 planes between the 8 bins
      AABB leftBox, rightBox;
      for (int i = 0; i < BINS - 1; i++) {
        leftSum += bin[i].triCount;
        leftBox.grow(bin[i].bounds);
        leftCountArea[i] = leftSum * leftBox.area();
        rightSum += bin[BINS - 1 - i].triCount;
        rightBox.grow(bin[BINS - 1 - i].bounds);
        rightCountArea[BINS - 2 - i] = rightSum * rightBox.area();
      }
    }

    // calculate SAH cost for the 7 planes
    scale = (boundsMax - boundsMin) / BINS;
    for (int i = 0; i < BINS - 1; i++) {
      const float planeCost = leftCountArea[i] + rightCountArea[i];
      if (planeCost < best_cost)
        best_axis = axis, best_pos = i + 1, best_cost = planeCost;
    }
  }
  return best_cost;
}

std::vector<int> BVH::Intersect(Triangle &tri_intersect) {
  /**
   * @brief Intersect a triangle with the BVH
   *
   * @param triangle the triangle to intersect
   *
   * @return -1 for no intersection, the index of the intersected triangle
   * otherwise
   */

  const int max_stack_size = 64;
  int node_stack[max_stack_size];
  int stack_size = 0;
  std::vector<int> intersected_triangles;

  node_stack[stack_size++] = 0; // Start with the root node (index 0)
  while (stack_size > 0) {
    int node_idx = node_stack[--stack_size];
    const BVHNode &node = bvhNode[node_idx];
    if (node.is_leaf()) {
      for (int i = node.start; i < node.end; ++i) {
        const Triangle &tri = triangle[triIdx[i]];
        // Check that the triangle is not the same as the intersected triangle
        if (tri == tri_intersect)
          continue;
        if (tri_intersect.overlaps(tri)) {
          intersected_triangles.push_back(actualIdx[triIdx[i]]);
        }
      }
    } else {
      // Check right child first
      if (bvhNode[node.right].bbox.overlaps(tri_intersect)) {
        if (stack_size < max_stack_size) {
          node_stack[stack_size++] = node.right;
        } else {
          throw std::runtime_error("Node stack overflow");
        }
      }

      // Check left child
      if (bvhNode[node.left].bbox.overlaps(tri_intersect)) {
        if (stack_size < max_stack_size) {
          node_stack[stack_size++] = node.left;
        } else {
          throw std::runtime_error("Node stack overflow");
        }
      }
    }
  }
  return intersected_triangles; // Return all intersected triangle indices
}

} // namespace UVUnwrapper
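FindBestSplitPlane is a binned surface-area heuristic specialised to 2D: for each axis it drops the triangle centroids into BINS bins, sweeps the BINS - 1 candidate planes between them, and scores each plane i with the cost

C(i) = N_L(i) \cdot A_L(i) + N_R(i) \cdot A_R(i)

where N_L and N_R are the triangle counts on either side of the plane and A_L and A_R are the areas of their 2D bounding boxes (leftCountArea and rightCountArea in the code). Subdivide only splits a node when the best plane cost is strictly lower than the node's own cost, num_triangles() * bbox.area(), which is exactly what the cost >= node.calculate_node_cost() early-out implements; otherwise the node becomes a leaf with left = right = -1.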
uv_unwrapper/uv_unwrapper/csrc/bvh.h
ADDED
@@ -0,0 +1,118 @@
#pragma once

#include <cfloat>
#include <cmath>
#ifndef __ARM_ARCH_ISA_A64
#include <immintrin.h>
#endif
#include <limits>
#include <vector>

#include "common.h"
#include "intersect.h"
/**
 * Based on https://github.com/jbikker/bvh_article released under the unlicense.
 */

// bin count for binned BVH building
#define BINS 8

namespace UVUnwrapper {
// minimalist triangle struct
struct alignas(32) Triangle {
  uv_float2 v0;
  uv_float2 v1;
  uv_float2 v2;
  uv_float2 centroid;

  bool overlaps(const Triangle &other) {
    // return tri_tri_overlap_test_2d(v0, v1, v2, other.v0, other.v1, other.v2);
    return triangle_triangle_intersection(v0, v1, v2, other.v0, other.v1,
                                           other.v2);
  }

  bool operator==(const Triangle &rhs) const {
    return v0 == rhs.v0 && v1 == rhs.v1 && v2 == rhs.v2;
  }
};

// minimalist AABB struct with grow functionality
struct alignas(16) AABB {
  // Init bounding boxes with max/min
  uv_float2 min = {FLT_MAX, FLT_MAX};
  uv_float2 max = {FLT_MIN, FLT_MIN};

  void grow(const uv_float2 &p) {
    min.x = std::min(min.x, p.x);
    min.y = std::min(min.y, p.y);
    max.x = std::max(max.x, p.x);
    max.y = std::max(max.y, p.y);
  }

  void grow(const AABB &b) {
    if (b.min.x != FLT_MAX) {
      grow(b.min);
      grow(b.max);
    }
  }

  bool overlaps(const Triangle &tri) {
    return triangle_aabb_intersection(min, max, tri.v0, tri.v1, tri.v2);
  }

  float area() const {
    uv_float2 extent = {max.x - min.x, max.y - min.y};
    return extent.x * extent.y;
  }

  void invalidate() {
    min = {FLT_MAX, FLT_MAX};
    max = {FLT_MIN, FLT_MIN};
  }
};

// 32-byte BVH node struct
struct alignas(32) BVHNode {
  AABB bbox;              // 16
  int start = 0, end = 0; // 8
  int left, right;

  int num_triangles() const { return end - start; }

  bool is_leaf() const { return left == -1 && right == -1; }

  float calculate_node_cost() {
    float area = bbox.area();
    return num_triangles() * area;
  }
};

class BVH {
public:
  BVH() = default;
  BVH(BVH &&other) noexcept;
  BVH(const BVH &other);
  BVH &operator=(const BVH &other);
  BVH &operator=(BVH &&other) noexcept;
  BVH(Triangle *tri, int *actual_idx, const size_t &num_indices);
  ~BVH();

  std::vector<int> Intersect(Triangle &triangle);

private:
  void Subdivide(unsigned int node_idx, unsigned int &nodePtr,
                 AABB &centroidBounds);
  void UpdateNodeBounds(unsigned int nodeIdx, AABB &centroidBounds);
  float FindBestSplitPlane(BVHNode &node, int &axis, int &splitPos,
                           AABB &centroidBounds);

public:
  int *triIdx = nullptr;
  int *actualIdx = nullptr;
  unsigned int triCount;
  unsigned int nodesUsed;
  BVHNode *bvhNode = nullptr;
  Triangle *triangle = nullptr;
};

} // namespace UVUnwrapper
uv_unwrapper/uv_unwrapper/csrc/common.h
ADDED
@@ -0,0 +1,493 @@
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <array>
|
| 4 |
+
#include <cmath>
|
| 5 |
+
#include <iostream>
|
| 6 |
+
#include <stdexcept>
|
| 7 |
+
|
| 8 |
+
const float EPSILON = 1e-7f;
|
| 9 |
+
|
| 10 |
+
// Structure to represent a 2D point or vector
|
| 11 |
+
union alignas(8) uv_float2 {
|
| 12 |
+
struct {
|
| 13 |
+
float x, y;
|
| 14 |
+
};
|
| 15 |
+
|
| 16 |
+
float data[2];
|
| 17 |
+
|
| 18 |
+
float &operator[](size_t idx) {
|
| 19 |
+
if (idx > 1)
|
| 20 |
+
throw std::runtime_error("bad index");
|
| 21 |
+
return data[idx];
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
const float &operator[](size_t idx) const {
|
| 25 |
+
if (idx > 1)
|
| 26 |
+
throw std::runtime_error("bad index");
|
| 27 |
+
return data[idx];
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
bool operator==(const uv_float2 &rhs) const {
|
| 31 |
+
return x == rhs.x && y == rhs.y;
|
| 32 |
+
}
|
| 33 |
+
};
|
| 34 |
+
|
| 35 |
+
// Do not align as this is specifically tweaked for BVHNode
|
| 36 |
+
union uv_float3 {
|
| 37 |
+
struct {
|
| 38 |
+
float x, y, z;
|
| 39 |
+
};
|
| 40 |
+
|
| 41 |
+
float data[3];
|
| 42 |
+
|
| 43 |
+
float &operator[](size_t idx) {
|
| 44 |
+
if (idx > 3)
|
| 45 |
+
throw std::runtime_error("bad index");
|
| 46 |
+
return data[idx];
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
const float &operator[](size_t idx) const {
|
| 50 |
+
if (idx > 3)
|
| 51 |
+
throw std::runtime_error("bad index");
|
| 52 |
+
return data[idx];
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
bool operator==(const uv_float3 &rhs) const {
|
| 56 |
+
return x == rhs.x && y == rhs.y && z == rhs.z;
|
| 57 |
+
}
|
| 58 |
+
};
|
| 59 |
+
|
| 60 |
+
union alignas(16) uv_float4 {
|
| 61 |
+
struct {
|
| 62 |
+
float x, y, z, w;
|
| 63 |
+
};
|
| 64 |
+
|
| 65 |
+
float data[4];
|
| 66 |
+
|
| 67 |
+
float &operator[](size_t idx) {
|
| 68 |
+
if (idx > 3)
|
| 69 |
+
throw std::runtime_error("bad index");
|
| 70 |
+
return data[idx];
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
const float &operator[](size_t idx) const {
|
| 74 |
+
if (idx > 3)
|
| 75 |
+
throw std::runtime_error("bad index");
|
| 76 |
+
return data[idx];
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
bool operator==(const uv_float4 &rhs) const {
|
| 80 |
+
return x == rhs.x && y == rhs.y && z == rhs.z && w == rhs.w;
|
| 81 |
+
}
|
| 82 |
+
};
|
| 83 |
+
|
| 84 |
+
union alignas(8) uv_int2 {
|
| 85 |
+
struct {
|
| 86 |
+
int x, y;
|
| 87 |
+
};
|
| 88 |
+
|
| 89 |
+
int data[2];
|
| 90 |
+
|
| 91 |
+
int &operator[](size_t idx) {
|
| 92 |
+
if (idx > 1)
|
| 93 |
+
throw std::runtime_error("bad index");
|
| 94 |
+
return data[idx];
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
const int &operator[](size_t idx) const {
|
| 98 |
+
if (idx > 1)
|
| 99 |
+
throw std::runtime_error("bad index");
|
| 100 |
+
return data[idx];
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
bool operator==(const uv_int2 &rhs) const { return x == rhs.x && y == rhs.y; }
|
| 104 |
+
};
|
| 105 |
+
|
| 106 |
+
union alignas(4) uv_int3 {
|
| 107 |
+
struct {
|
| 108 |
+
int x, y, z;
|
| 109 |
+
};
|
| 110 |
+
|
| 111 |
+
int data[3];
|
| 112 |
+
|
| 113 |
+
int &operator[](size_t idx) {
|
| 114 |
+
if (idx > 2)
|
| 115 |
+
throw std::runtime_error("bad index");
|
| 116 |
+
return data[idx];
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
const int &operator[](size_t idx) const {
|
| 120 |
+
if (idx > 2)
|
| 121 |
+
throw std::runtime_error("bad index");
|
| 122 |
+
return data[idx];
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
bool operator==(const uv_int3 &rhs) const {
|
| 126 |
+
return x == rhs.x && y == rhs.y && z == rhs.z;
|
| 127 |
+
}
|
| 128 |
+
};
|
| 129 |
+
|
| 130 |
+
union alignas(16) uv_int4 {
|
| 131 |
+
struct {
|
| 132 |
+
int x, y, z, w;
|
| 133 |
+
};
|
| 134 |
+
|
| 135 |
+
int data[4];
|
| 136 |
+
|
| 137 |
+
int &operator[](size_t idx) {
|
| 138 |
+
if (idx > 3)
|
| 139 |
+
throw std::runtime_error("bad index");
|
| 140 |
+
return data[idx];
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
const int &operator[](size_t idx) const {
|
| 144 |
+
if (idx > 3)
|
| 145 |
+
throw std::runtime_error("bad index");
|
| 146 |
+
return data[idx];
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
bool operator==(const uv_int4 &rhs) const {
|
| 150 |
+
return x == rhs.x && y == rhs.y && z == rhs.z && w == rhs.w;
|
| 151 |
+
}
|
| 152 |
+
};
|
| 153 |
+
|
| 154 |
+
inline float calc_mean(float a, float b, float c) { return (a + b + c) / 3; }
|
| 155 |
+
|
| 156 |
+
// Create a triangle centroid
|
| 157 |
+
inline uv_float2 triangle_centroid(const uv_float2 &v0, const uv_float2 &v1,
|
| 158 |
+
const uv_float2 &v2) {
|
| 159 |
+
return {calc_mean(v0.x, v1.x, v2.x), calc_mean(v0.y, v1.y, v2.y)};
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
inline uv_float3 triangle_centroid(const uv_float3 &v0, const uv_float3 &v1,
|
| 163 |
+
const uv_float3 &v2) {
|
| 164 |
+
return {calc_mean(v0.x, v1.x, v2.x), calc_mean(v0.y, v1.y, v2.y),
|
| 165 |
+
calc_mean(v0.z, v1.z, v2.z)};
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
// Helper functions for vector math
|
| 169 |
+
inline uv_float2 operator-(const uv_float2 &a, const uv_float2 &b) {
|
| 170 |
+
return {a.x - b.x, a.y - b.y};
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
inline uv_float3 operator-(const uv_float3 &a, const uv_float3 &b) {
|
| 174 |
+
return {a.x - b.x, a.y - b.y, a.z - b.z};
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
inline uv_float2 operator+(const uv_float2 &a, const uv_float2 &b) {
|
| 178 |
+
return {a.x + b.x, a.y + b.y};
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
inline uv_float3 operator+(const uv_float3 &a, const uv_float3 &b) {
|
| 182 |
+
return {a.x + b.x, a.y + b.y, a.z + b.z};
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
inline uv_float2 operator*(const uv_float2 &a, float scalar) {
|
| 186 |
+
return {a.x * scalar, a.y * scalar};
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
inline uv_float3 operator*(const uv_float3 &a, float scalar) {
|
| 190 |
+
return {a.x * scalar, a.y * scalar, a.z * scalar};
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
inline float dot(const uv_float2 &a, const uv_float2 &b) {
|
| 194 |
+
return a.x * b.x + a.y * b.y;
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
inline float dot(const uv_float3 &a, const uv_float3 &b) {
|
| 198 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
inline float cross(const uv_float2 &a, const uv_float2 &b) {
|
| 202 |
+
return a.x * b.y - a.y * b.x;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
inline uv_float3 cross(const uv_float3 &a, const uv_float3 &b) {
|
| 206 |
+
return {a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x};
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
inline uv_float2 abs_vec(const uv_float2 &v) {
|
| 210 |
+
return {std::abs(v.x), std::abs(v.y)};
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
inline uv_float2 min_vec(const uv_float2 &a, const uv_float2 &b) {
|
| 214 |
+
return {std::min(a.x, b.x), std::min(a.y, b.y)};
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
inline uv_float2 max_vec(const uv_float2 &a, const uv_float2 &b) {
|
| 218 |
+
return {std::max(a.x, b.x), std::max(a.y, b.y)};
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
inline float distance_to(const uv_float2 &a, const uv_float2 &b) {
|
| 222 |
+
return std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2));
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
inline float distance_to(const uv_float3 &a, const uv_float3 &b) {
|
| 226 |
+
return std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2) +
|
| 227 |
+
std::pow(a.z - b.z, 2));
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
inline uv_float2 normalize(const uv_float2 &v) {
|
| 231 |
+
float len = std::sqrt(v.x * v.x + v.y * v.y);
|
| 232 |
+
return {v.x / len, v.y / len};
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
inline uv_float3 normalize(const uv_float3 &v) {
|
| 236 |
+
float len = std::sqrt(v.x * v.x + v.y * v.y + v.z * v.z);
|
| 237 |
+
return {v.x / len, v.y / len, v.z / len};
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
inline float magnitude(const uv_float3 &v) {
|
| 241 |
+
return std::sqrt(v.x * v.x + v.y * v.y + v.z * v.z);
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
struct Matrix4 {
|
| 245 |
+
std::array<std::array<float, 4>, 4> m;
|
| 246 |
+
|
| 247 |
+
Matrix4() {
|
| 248 |
+
for (auto &row : m) {
|
| 249 |
+
row.fill(0.0f);
|
| 250 |
+
}
|
| 251 |
+
m[3][3] = 1.0f; // Identity matrix for 4th row and column
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
void set(float m00, float m01, float m02, float m03, float m10, float m11,
|
| 255 |
+
float m12, float m13, float m20, float m21, float m22, float m23,
|
| 256 |
+
float m30, float m31, float m32, float m33) {
|
| 257 |
+
m[0][0] = m00;
|
| 258 |
+
m[0][1] = m01;
|
| 259 |
+
m[0][2] = m02;
|
| 260 |
+
m[0][3] = m03;
|
| 261 |
+
m[1][0] = m10;
|
| 262 |
+
m[1][1] = m11;
|
| 263 |
+
m[1][2] = m12;
|
| 264 |
+
m[1][3] = m13;
|
| 265 |
+
m[2][0] = m20;
|
| 266 |
+
m[2][1] = m21;
|
| 267 |
+
m[2][2] = m22;
|
| 268 |
+
m[2][3] = m23;
|
| 269 |
+
m[3][0] = m30;
|
| 270 |
+
m[3][1] = m31;
|
| 271 |
+
m[3][2] = m32;
|
| 272 |
+
m[3][3] = m33;
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
float determinant() const {
|
| 276 |
+
return m[0][3] * m[1][2] * m[2][1] * m[3][0] -
|
| 277 |
+
m[0][2] * m[1][3] * m[2][1] * m[3][0] -
|
| 278 |
+
m[0][3] * m[1][1] * m[2][2] * m[3][0] +
|
| 279 |
+
m[0][1] * m[1][3] * m[2][2] * m[3][0] +
|
| 280 |
+
m[0][2] * m[1][1] * m[2][3] * m[3][0] -
|
| 281 |
+
m[0][1] * m[1][2] * m[2][3] * m[3][0] -
|
| 282 |
+
m[0][3] * m[1][2] * m[2][0] * m[3][1] +
|
| 283 |
+
m[0][2] * m[1][3] * m[2][0] * m[3][1] +
|
| 284 |
+
m[0][3] * m[1][0] * m[2][2] * m[3][1] -
|
| 285 |
+
m[0][0] * m[1][3] * m[2][2] * m[3][1] -
|
| 286 |
+
m[0][2] * m[1][0] * m[2][3] * m[3][1] +
|
| 287 |
+
m[0][0] * m[1][2] * m[2][3] * m[3][1] +
|
| 288 |
+
m[0][3] * m[1][1] * m[2][0] * m[3][2] -
|
| 289 |
+
m[0][1] * m[1][3] * m[2][0] * m[3][2] -
|
| 290 |
+
m[0][3] * m[1][0] * m[2][1] * m[3][2] +
|
| 291 |
+
m[0][0] * m[1][3] * m[2][1] * m[3][2] +
|
| 292 |
+
m[0][1] * m[1][0] * m[2][3] * m[3][2] -
|
| 293 |
+
m[0][0] * m[1][1] * m[2][3] * m[3][2] -
|
| 294 |
+
m[0][2] * m[1][1] * m[2][0] * m[3][3] +
|
| 295 |
+
m[0][1] * m[1][2] * m[2][0] * m[3][3] +
|
| 296 |
+
m[0][2] * m[1][0] * m[2][1] * m[3][3] -
|
| 297 |
+
m[0][0] * m[1][2] * m[2][1] * m[3][3] -
|
| 298 |
+
m[0][1] * m[1][0] * m[2][2] * m[3][3] +
|
| 299 |
+
m[0][0] * m[1][1] * m[2][2] * m[3][3];
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
Matrix4 operator*(const Matrix4 &other) const {
|
| 303 |
+
Matrix4 result;
|
| 304 |
+
for (int row = 0; row < 4; ++row) {
|
| 305 |
+
for (int col = 0; col < 4; ++col) {
|
| 306 |
+
result.m[row][col] =
|
| 307 |
+
m[row][0] * other.m[0][col] + m[row][1] * other.m[1][col] +
|
| 308 |
+
m[row][2] * other.m[2][col] + m[row][3] * other.m[3][col];
|
| 309 |
+
}
|
| 310 |
+
}
|
| 311 |
+
return result;
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
Matrix4 operator*(float scalar) const {
|
| 315 |
+
Matrix4 result = *this;
|
| 316 |
+
for (auto &row : result.m) {
|
| 317 |
+
for (auto &element : row) {
|
| 318 |
+
element *= scalar;
|
| 319 |
+
}
|
| 320 |
+
}
|
| 321 |
+
return result;
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
Matrix4 operator+(const Matrix4 &other) const {
|
| 325 |
+
Matrix4 result;
|
| 326 |
+
for (int i = 0; i < 4; ++i) {
|
| 327 |
+
for (int j = 0; j < 4; ++j) {
|
| 328 |
+
result.m[i][j] = m[i][j] + other.m[i][j];
|
| 329 |
+
}
|
| 330 |
+
}
|
| 331 |
+
return result;
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
Matrix4 operator-(const Matrix4 &other) const {
|
| 335 |
+
Matrix4 result;
|
| 336 |
+
for (int i = 0; i < 4; ++i) {
|
| 337 |
+
for (int j = 0; j < 4; ++j) {
|
| 338 |
+
result.m[i][j] = m[i][j] - other.m[i][j];
|
| 339 |
+
}
|
| 340 |
+
}
|
| 341 |
+
return result;
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
float trace() const { return m[0][0] + m[1][1] + m[2][2] + m[3][3]; }
|
| 345 |
+
|
| 346 |
+
Matrix4 identity() const {
|
| 347 |
+
Matrix4 identity;
|
| 348 |
+
identity.set(1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1);
|
| 349 |
+
return identity;
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
Matrix4 power(int exp) const {
|
| 353 |
+
if (exp == 0)
|
| 354 |
+
return identity();
|
| 355 |
+
if (exp == 1)
|
| 356 |
+
return *this;
|
| 357 |
+
|
| 358 |
+
Matrix4 result = *this;
|
| 359 |
+
for (int i = 1; i < exp; ++i) {
|
| 360 |
+
result = result * (*this);
|
| 361 |
+
}
|
| 362 |
+
return result;
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
void print() {
|
| 366 |
+
// Print all entries in 4 rows with 4 columns
|
| 367 |
+
for (int i = 0; i < 4; ++i) {
|
| 368 |
+
for (int j = 0; j < 4; ++j) {
|
| 369 |
+
std::cout << m[i][j] << " ";
|
| 370 |
+
}
|
| 371 |
+
std::cout << std::endl;
|
| 372 |
+
}
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
bool invert() {
|
| 376 |
+
double inv[16], det;
|
| 377 |
+
double mArr[16];
|
| 378 |
+
|
| 379 |
+
// Convert the matrix to a 1D array for easier manipulation
|
| 380 |
+
for (int i = 0; i < 4; ++i) {
|
| 381 |
+
for (int j = 0; j < 4; ++j) {
|
| 382 |
+
mArr[i * 4 + j] = static_cast<double>(m[i][j]);
|
| 383 |
+
}
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
inv[0] = mArr[5] * mArr[10] * mArr[15] - mArr[5] * mArr[11] * mArr[14] -
|
| 387 |
+
mArr[9] * mArr[6] * mArr[15] + mArr[9] * mArr[7] * mArr[14] +
|
| 388 |
+
mArr[13] * mArr[6] * mArr[11] - mArr[13] * mArr[7] * mArr[10];
|
| 389 |
+
|
| 390 |
+
inv[4] = -mArr[4] * mArr[10] * mArr[15] + mArr[4] * mArr[11] * mArr[14] +
|
| 391 |
+
mArr[8] * mArr[6] * mArr[15] - mArr[8] * mArr[7] * mArr[14] -
|
| 392 |
+
mArr[12] * mArr[6] * mArr[11] + mArr[12] * mArr[7] * mArr[10];
|
| 393 |
+
|
| 394 |
+
inv[8] = mArr[4] * mArr[9] * mArr[15] - mArr[4] * mArr[11] * mArr[13] -
|
| 395 |
+
mArr[8] * mArr[5] * mArr[15] + mArr[8] * mArr[7] * mArr[13] +
|
| 396 |
+
mArr[12] * mArr[5] * mArr[11] - mArr[12] * mArr[7] * mArr[9];
|
| 397 |
+
|
| 398 |
+
inv[12] = -mArr[4] * mArr[9] * mArr[14] + mArr[4] * mArr[10] * mArr[13] +
|
| 399 |
+
mArr[8] * mArr[5] * mArr[14] - mArr[8] * mArr[6] * mArr[13] -
|
| 400 |
+
mArr[12] * mArr[5] * mArr[10] + mArr[12] * mArr[6] * mArr[9];
|
| 401 |
+
|
| 402 |
+
inv[1] = -mArr[1] * mArr[10] * mArr[15] + mArr[1] * mArr[11] * mArr[14] +
|
| 403 |
+
mArr[9] * mArr[2] * mArr[15] - mArr[9] * mArr[3] * mArr[14] -
|
| 404 |
+
mArr[13] * mArr[2] * mArr[11] + mArr[13] * mArr[3] * mArr[10];
|
| 405 |
+
|
| 406 |
+
inv[5] = mArr[0] * mArr[10] * mArr[15] - mArr[0] * mArr[11] * mArr[14] -
|
| 407 |
+
mArr[8] * mArr[2] * mArr[15] + mArr[8] * mArr[3] * mArr[14] +
|
| 408 |
+
mArr[12] * mArr[2] * mArr[11] - mArr[12] * mArr[3] * mArr[10];
|
| 409 |
+
|
| 410 |
+
inv[9] = -mArr[0] * mArr[9] * mArr[15] + mArr[0] * mArr[11] * mArr[13] +
|
| 411 |
+
mArr[8] * mArr[1] * mArr[15] - mArr[8] * mArr[3] * mArr[13] -
|
| 412 |
+
mArr[12] * mArr[1] * mArr[11] + mArr[12] * mArr[3] * mArr[9];
|
| 413 |
+
|
| 414 |
+
inv[13] = mArr[0] * mArr[9] * mArr[14] - mArr[0] * mArr[10] * mArr[13] -
|
| 415 |
+
mArr[8] * mArr[1] * mArr[14] + mArr[8] * mArr[2] * mArr[13] +
|
| 416 |
+
mArr[12] * mArr[1] * mArr[10] - mArr[12] * mArr[2] * mArr[9];
|
| 417 |
+
|
| 418 |
+
inv[2] = mArr[1] * mArr[6] * mArr[15] - mArr[1] * mArr[7] * mArr[14] -
|
| 419 |
+
mArr[5] * mArr[2] * mArr[15] + mArr[5] * mArr[3] * mArr[14] +
|
| 420 |
+
mArr[13] * mArr[2] * mArr[7] - mArr[13] * mArr[3] * mArr[6];
|
| 421 |
+
|
| 422 |
+
inv[6] = -mArr[0] * mArr[6] * mArr[15] + mArr[0] * mArr[7] * mArr[14] +
|
| 423 |
+
mArr[4] * mArr[2] * mArr[15] - mArr[4] * mArr[3] * mArr[14] -
|
| 424 |
+
mArr[12] * mArr[2] * mArr[7] + mArr[12] * mArr[3] * mArr[6];
|
| 425 |
+
|
| 426 |
+
inv[10] = mArr[0] * mArr[5] * mArr[15] - mArr[0] * mArr[7] * mArr[13] -
|
| 427 |
+
mArr[4] * mArr[1] * mArr[15] + mArr[4] * mArr[3] * mArr[13] +
|
| 428 |
+
mArr[12] * mArr[1] * mArr[7] - mArr[12] * mArr[3] * mArr[5];
|
| 429 |
+
|
| 430 |
+
inv[14] = -mArr[0] * mArr[5] * mArr[14] + mArr[0] * mArr[6] * mArr[13] +
|
| 431 |
+
mArr[4] * mArr[1] * mArr[14] - mArr[4] * mArr[2] * mArr[13] -
|
| 432 |
+
mArr[12] * mArr[1] * mArr[6] + mArr[12] * mArr[2] * mArr[5];
|
| 433 |
+
|
| 434 |
+
inv[3] = -mArr[1] * mArr[6] * mArr[11] + mArr[1] * mArr[7] * mArr[10] +
|
| 435 |
+
mArr[5] * mArr[2] * mArr[11] - mArr[5] * mArr[3] * mArr[10] -
|
| 436 |
+
mArr[9] * mArr[2] * mArr[7] + mArr[9] * mArr[3] * mArr[6];
|
| 437 |
+
|
| 438 |
+
inv[7] = mArr[0] * mArr[6] * mArr[11] - mArr[0] * mArr[7] * mArr[10] -
|
| 439 |
+
mArr[4] * mArr[2] * mArr[11] + mArr[4] * mArr[3] * mArr[10] +
|
| 440 |
+
mArr[8] * mArr[2] * mArr[7] - mArr[8] * mArr[3] * mArr[6];
|
| 441 |
+
|
| 442 |
+
inv[11] = -mArr[0] * mArr[5] * mArr[11] + mArr[0] * mArr[7] * mArr[9] +
|
| 443 |
+
mArr[4] * mArr[1] * mArr[11] - mArr[4] * mArr[3] * mArr[9] -
|
| 444 |
+
mArr[8] * mArr[1] * mArr[7] + mArr[8] * mArr[3] * mArr[5];
|
| 445 |
+
|
| 446 |
+
inv[15] = mArr[0] * mArr[5] * mArr[10] - mArr[0] * mArr[6] * mArr[9] -
|
| 447 |
+
mArr[4] * mArr[1] * mArr[10] + mArr[4] * mArr[2] * mArr[9] +
|
| 448 |
+
mArr[8] * mArr[1] * mArr[6] - mArr[8] * mArr[2] * mArr[5];
|
| 449 |
+
|
| 450 |
+
det = mArr[0] * inv[0] + mArr[1] * inv[4] + mArr[2] * inv[8] +
|
| 451 |
+
mArr[3] * inv[12];
|
| 452 |
+
|
| 453 |
+
if (fabs(det) < 1e-6) {
|
| 454 |
+
return false;
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
det = 1.0 / det;
|
| 458 |
+
|
| 459 |
+
for (int i = 0; i < 16; i++) {
|
| 460 |
+
inv[i] *= det;
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
// Convert the 1D array back to the 4x4 matrix
|
| 464 |
+
for (int i = 0; i < 4; ++i) {
|
| 465 |
+
for (int j = 0; j < 4; ++j) {
|
| 466 |
+
m[i][j] = static_cast<float>(inv[i * 4 + j]);
|
| 467 |
+
}
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
return true;
|
| 471 |
+
}
|
| 472 |
+
};
|
| 473 |
+
|
| 474 |
+
inline void apply_matrix4(uv_float3 &v, const Matrix4 matrix) {
|
| 475 |
+
float newX = v.x * matrix.m[0][0] + v.y * matrix.m[0][1] +
|
| 476 |
+
v.z * matrix.m[0][2] + matrix.m[0][3];
|
| 477 |
+
float newY = v.x * matrix.m[1][0] + v.y * matrix.m[1][1] +
|
| 478 |
+
v.z * matrix.m[1][2] + matrix.m[1][3];
|
| 479 |
+
float newZ = v.x * matrix.m[2][0] + v.y * matrix.m[2][1] +
|
| 480 |
+
v.z * matrix.m[2][2] + matrix.m[2][3];
|
| 481 |
+
float w = v.x * matrix.m[3][0] + v.y * matrix.m[3][1] + v.z * matrix.m[3][2] +
|
| 482 |
+
matrix.m[3][3];
|
| 483 |
+
|
| 484 |
+
if (std::fabs(w) > EPSILON) {
|
| 485 |
+
newX /= w;
|
| 486 |
+
newY /= w;
|
| 487 |
+
newZ /= w;
|
| 488 |
+
}
|
| 489 |
+
|
| 490 |
+
v.x = newX;
|
| 491 |
+
v.y = newY;
|
| 492 |
+
v.z = newZ;
|
| 493 |
+
}
|
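The apply_matrix4 helper at the end of common.h treats the 3D point as a homogeneous coordinate: it computes

(x', y', z', w)^T = M \, (x, y, z, 1)^T

and, whenever |w| > EPSILON, divides through by w so the result is the perspective-correct transform v' = (x'/w, \; y'/w, \; z'/w); for purely affine matrices w stays 1 and the divide is a no-op.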
uv_unwrapper/uv_unwrapper/csrc/intersect.cpp
ADDED
@@ -0,0 +1,702 @@
| 1 |
+
#include "intersect.h"
|
| 2 |
+
#include "bvh.h"
|
| 3 |
+
#include <algorithm>
|
| 4 |
+
#include <cmath>
|
| 5 |
+
#include <iostream>
|
| 6 |
+
#include <stdexcept>
|
| 7 |
+
#include <vector>
|
| 8 |
+
|
| 9 |
+
bool triangle_aabb_intersection(const uv_float2 &aabbMin,
|
| 10 |
+
const uv_float2 &aabbMax, const uv_float2 &v0,
|
| 11 |
+
const uv_float2 &v1, const uv_float2 &v2) {
|
| 12 |
+
// Convert the min and max aabb defintion to left, right, top, bottom
|
| 13 |
+
float l = aabbMin.x;
|
| 14 |
+
float r = aabbMax.x;
|
| 15 |
+
float t = aabbMin.y;
|
| 16 |
+
float b = aabbMax.y;
|
| 17 |
+
|
| 18 |
+
int b0 = ((v0.x > l) ? 1 : 0) | ((v0.y > t) ? 2 : 0) | ((v0.x > r) ? 4 : 0) |
|
| 19 |
+
            ((v0.y > b) ? 8 : 0);
  if (b0 == 3)
    return true;

  int b1 = ((v1.x > l) ? 1 : 0) | ((v1.y > t) ? 2 : 0) | ((v1.x > r) ? 4 : 0) |
           ((v1.y > b) ? 8 : 0);
  if (b1 == 3)
    return true;

  int b2 = ((v2.x > l) ? 1 : 0) | ((v2.y > t) ? 2 : 0) | ((v2.x > r) ? 4 : 0) |
           ((v2.y > b) ? 8 : 0);
  if (b2 == 3)
    return true;

  float m, c, s;

  int i0 = b0 ^ b1;
  if (i0 != 0) {
    if (v1.x != v0.x) {
      m = (v1.y - v0.y) / (v1.x - v0.x);
      c = v0.y - (m * v0.x);
      if (i0 & 1) {
        s = m * l + c;
        if (s >= t && s <= b)
          return true;
      }
      if (i0 & 2) {
        s = (t - c) / m;
        if (s >= l && s <= r)
          return true;
      }
      if (i0 & 4) {
        s = m * r + c;
        if (s >= t && s <= b)
          return true;
      }
      if (i0 & 8) {
        s = (b - c) / m;
        if (s >= l && s <= r)
          return true;
      }
    } else {
      if (l == v0.x || r == v0.x)
        return true;
      if (v0.x > l && v0.x < r)
        return true;
    }
  }

  int i1 = b1 ^ b2;
  if (i1 != 0) {
    if (v2.x != v1.x) {
      m = (v2.y - v1.y) / (v2.x - v1.x);
      c = v1.y - (m * v1.x);
      if (i1 & 1) {
        s = m * l + c;
        if (s >= t && s <= b)
          return true;
      }
      if (i1 & 2) {
        s = (t - c) / m;
        if (s >= l && s <= r)
          return true;
      }
      if (i1 & 4) {
        s = m * r + c;
        if (s >= t && s <= b)
          return true;
      }
      if (i1 & 8) {
        s = (b - c) / m;
        if (s >= l && s <= r)
          return true;
      }
    } else {
      if (l == v1.x || r == v1.x)
        return true;
      if (v1.x > l && v1.x < r)
        return true;
    }
  }

  int i2 = b0 ^ b2;
  if (i2 != 0) {
    if (v2.x != v0.x) {
      m = (v2.y - v0.y) / (v2.x - v0.x);
      c = v0.y - (m * v0.x);
      if (i2 & 1) {
        s = m * l + c;
        if (s >= t && s <= b)
          return true;
      }
      if (i2 & 2) {
        s = (t - c) / m;
        if (s >= l && s <= r)
          return true;
      }
      if (i2 & 4) {
        s = m * r + c;
        if (s >= t && s <= b)
          return true;
      }
      if (i2 & 8) {
        s = (b - c) / m;
        if (s >= l && s <= r)
          return true;
      }
    } else {
      if (l == v0.x || r == v0.x)
        return true;
      if (v0.x > l && v0.x < r)
        return true;
    }
  }

  // Bounding box check
  float tbb_l = std::min(v0.x, std::min(v1.x, v2.x));
  float tbb_t = std::min(v0.y, std::min(v1.y, v2.y));
  float tbb_r = std::max(v0.x, std::max(v1.x, v2.x));
  float tbb_b = std::max(v0.y, std::max(v1.y, v2.y));

  if (tbb_l <= l && tbb_r >= r && tbb_t <= t && tbb_b >= b) {
    float v0x = v2.x - v0.x;
    float v0y = v2.y - v0.y;
    float v1x = v1.x - v0.x;
    float v1y = v1.y - v0.y;
    float v2x, v2y;

    float dot00, dot01, dot02, dot11, dot12, invDenom, u, v;

    // Top-left corner
    v2x = l - v0.x;
    v2y = t - v0.y;

    dot00 = v0x * v0x + v0y * v0y;
    dot01 = v0x * v1x + v0y * v1y;
    dot02 = v0x * v2x + v0y * v2y;
    dot11 = v1x * v1x + v1y * v1y;
    dot12 = v1x * v2x + v1y * v2y;

    invDenom = 1.0f / (dot00 * dot11 - dot01 * dot01);
    u = (dot11 * dot02 - dot01 * dot12) * invDenom;
    v = (dot00 * dot12 - dot01 * dot02) * invDenom;

    if (u >= 0 && v >= 0 && (u + v) <= 1)
      return true;

    // Bottom-left corner
    v2x = l - v0.x;
    v2y = b - v0.y;

    dot02 = v0x * v2x + v0y * v2y;
    dot12 = v1x * v2x + v1y * v2y;

    u = (dot11 * dot02 - dot01 * dot12) * invDenom;
    v = (dot00 * dot12 - dot01 * dot02) * invDenom;

    if (u >= 0 && v >= 0 && (u + v) <= 1)
      return true;

    // Bottom-right corner
    v2x = r - v0.x;
    v2y = b - v0.y;

    dot02 = v0x * v2x + v0y * v2y;
    dot12 = v1x * v2x + v1y * v2y;

    u = (dot11 * dot02 - dot01 * dot12) * invDenom;
    v = (dot00 * dot12 - dot01 * dot02) * invDenom;

    if (u >= 0 && v >= 0 && (u + v) <= 1)
      return true;

    // Top-right corner
    v2x = r - v0.x;
    v2y = t - v0.y;

    dot02 = v0x * v2x + v0y * v2y;
    dot12 = v1x * v2x + v1y * v2y;

    u = (dot11 * dot02 - dot01 * dot12) * invDenom;
    v = (dot00 * dot12 - dot01 * dot02) * invDenom;

    if (u >= 0 && v >= 0 && (u + v) <= 1)
      return true;
  }

  return false;
}

void tri_winding(uv_float2 &a, uv_float2 &b, uv_float2 &c) {
  float det = (a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y));

  // If the determinant is negative, the triangle is oriented clockwise
  if (det < 0) {
    // Swap vertices b and c to ensure counter-clockwise winding
    std::swap(b, c);
  }
}

struct Triangle {
  uv_float3 a, b, c;

  Triangle(const uv_float2 &p1, const uv_float2 &q1, const uv_float2 &r1)
      : a({p1.x, p1.y, 0}), b({q1.x, q1.y, 0}), c({r1.x, r1.y, 0}) {}

  Triangle(const uv_float3 &p1, const uv_float3 &q1, const uv_float3 &r1)
      : a(p1), b(q1), c(r1) {}

  void getNormal(uv_float3 &normal) const {
    uv_float3 u = b - a;
    uv_float3 v = c - a;
    normal = normalize(cross(u, v));
  }
};

bool isTriDegenerated(const Triangle &tri) {
  uv_float3 u = tri.a - tri.b;
  uv_float3 v = tri.a - tri.c;
  uv_float3 cr = cross(u, v);
  return fabs(cr.x) < EPSILON && fabs(cr.y) < EPSILON && fabs(cr.z) < EPSILON;
}

int orient3D(const uv_float3 &a, const uv_float3 &b, const uv_float3 &c,
             const uv_float3 &d) {
  Matrix4 _matrix4;
  _matrix4.set(a.x, a.y, a.z, 1, b.x, b.y, b.z, 1, c.x, c.y, c.z, 1, d.x, d.y,
               d.z, 1);
  float det = _matrix4.determinant();

  if (det < -EPSILON)
    return -1;
  else if (det > EPSILON)
    return 1;
  else
    return 0;
}

int orient2D(const uv_float2 &a, const uv_float2 &b, const uv_float2 &c) {
  float det = (a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y));

  if (det < -EPSILON)
    return -1;
  else if (det > EPSILON)
    return 1;
  else
    return 0;
}

int orient2D(const uv_float3 &a, const uv_float3 &b, const uv_float3 &c) {
  uv_float2 a_2d = {a.x, a.y};
  uv_float2 b_2d = {b.x, b.y};
  uv_float2 c_2d = {c.x, c.y};
  return orient2D(a_2d, b_2d, c_2d);
}

void permuteTriLeft(Triangle &tri) {
  uv_float3 tmp = tri.a;
  tri.a = tri.b;
  tri.b = tri.c;
  tri.c = tmp;
}

void permuteTriRight(Triangle &tri) {
  uv_float3 tmp = tri.c;
  tri.c = tri.b;
  tri.b = tri.a;
  tri.a = tmp;
}

void makeTriCounterClockwise(Triangle &tri) {
  if (orient2D(tri.a, tri.b, tri.c) < 0) {
    uv_float3 tmp = tri.c;
    tri.c = tri.b;
    tri.b = tmp;
  }
}

void intersectPlane(const uv_float3 &a, const uv_float3 &b, const uv_float3 &p,
                    const uv_float3 &n, uv_float3 &target) {
  uv_float3 u = b - a;
  uv_float3 v = a - p;
  float dot1 = dot(n, u);
  float dot2 = dot(n, v);
  u = u * (-dot2 / dot1);
  target = a + u;
}

void computeLineIntersection(const Triangle &t1, const Triangle &t2,
                             std::vector<uv_float3> &target) {
  uv_float3 n1, n2;
  t1.getNormal(n1);
  t2.getNormal(n2);

  int o1 = orient3D(t1.a, t1.c, t2.b, t2.a);
  int o2 = orient3D(t1.a, t1.b, t2.c, t2.a);

  uv_float3 i1, i2;

  if (o1 > 0) {
    if (o2 > 0) {
      intersectPlane(t1.a, t1.c, t2.a, n2, i1);
      intersectPlane(t2.a, t2.c, t1.a, n1, i2);
    } else {
      intersectPlane(t1.a, t1.c, t2.a, n2, i1);
      intersectPlane(t1.a, t1.b, t2.a, n2, i2);
    }
  } else {
    if (o2 > 0) {
      intersectPlane(t2.a, t2.b, t1.a, n1, i1);
      intersectPlane(t2.a, t2.c, t1.a, n1, i2);
    } else {
      intersectPlane(t2.a, t2.b, t1.a, n1, i1);
      intersectPlane(t1.a, t1.b, t2.a, n2, i2);
    }
  }

  target.push_back(i1);
  if (distance_to(i1, i2) >= EPSILON) {
    target.push_back(i2);
  }
}

void makeTriAVertexAlone(Triangle &tri, int oa, int ob, int oc) {
  // Permute a, b, c so that a is alone on its side
  if (oa == ob) {
    // c is alone, permute right so c becomes a
    permuteTriRight(tri);
  } else if (oa == oc) {
    // b is alone, permute so b becomes a
    permuteTriLeft(tri);
  } else if (ob != oc) {
    // In case a, b, c have different orientation, put a on positive side
    if (ob > 0) {
      permuteTriLeft(tri);
    } else if (oc > 0) {
      permuteTriRight(tri);
    }
  }
}

void makeTriAVertexPositive(Triangle &tri, const Triangle &other) {
  int o = orient3D(other.a, other.b, other.c, tri.a);
  if (o < 0) {
    std::swap(tri.b, tri.c);
  }
}

bool crossIntersect(Triangle &t1, Triangle &t2, int o1a, int o1b, int o1c,
                    std::vector<uv_float3> *target = nullptr) {
  int o2a = orient3D(t1.a, t1.b, t1.c, t2.a);
  int o2b = orient3D(t1.a, t1.b, t1.c, t2.b);
  int o2c = orient3D(t1.a, t1.b, t1.c, t2.c);

  if (o2a == o2b && o2a == o2c) {
    return false;
  }

  // Make a vertex alone on its side for both triangles
  makeTriAVertexAlone(t1, o1a, o1b, o1c);
  makeTriAVertexAlone(t2, o2a, o2b, o2c);

  // Ensure the vertex on the positive side
  makeTriAVertexPositive(t2, t1);
  makeTriAVertexPositive(t1, t2);

  int o1 = orient3D(t1.a, t1.b, t2.a, t2.b);
  int o2 = orient3D(t1.a, t1.c, t2.c, t2.a);

  if (o1 <= 0 && o2 <= 0) {
    if (target) {
      computeLineIntersection(t1, t2, *target);
    }
    return true;
  }

  return false;
}

void linesIntersect2d(const uv_float3 &a1, const uv_float3 &b1,
                      const uv_float3 &a2, const uv_float3 &b2,
                      uv_float3 &target) {
  float dx1 = a1.x - b1.x;
  float dx2 = a2.x - b2.x;
  float dy1 = a1.y - b1.y;
  float dy2 = a2.y - b2.y;

  float D = dx1 * dy2 - dx2 * dy1;

  float n1 = a1.x * b1.y - a1.y * b1.x;
  float n2 = a2.x * b2.y - a2.y * b2.x;

  target.x = (n1 * dx2 - n2 * dx1) / D;
  target.y = (n1 * dy2 - n2 * dy1) / D;
  target.z = 0;
}

void clipTriangle(const Triangle &t1, const Triangle &t2,
                  std::vector<uv_float3> &target) {
  std::vector<uv_float3> clip = {t1.a, t1.b, t1.c};
  std::vector<uv_float3> output = {t2.a, t2.b, t2.c};
  std::vector<int> orients(output.size() * 3, 0);
  uv_float3 inter;

  for (int i = 0; i < 3; ++i) {
    const int i_prev = (i + 2) % 3;
    std::vector<uv_float3> input;
    std::copy(output.begin(), output.end(), std::back_inserter(input));
    output.clear();

    for (size_t j = 0; j < input.size(); ++j) {
      orients[j] = orient2D(clip[i_prev], clip[i], input[j]);
    }

    for (size_t j = 0; j < input.size(); ++j) {
      const int j_prev = (j - 1 + input.size()) % input.size();

      if (orients[j] >= 0) {
        if (orients[j_prev] < 0) {
          linesIntersect2d(clip[i_prev], clip[i], input[j_prev], input[j],
                           inter);
          output.push_back({inter.x, inter.y, inter.z});
        }
        output.push_back({input[j].x, input[j].y, input[j].z});
      } else if (orients[j_prev] >= 0) {
        linesIntersect2d(clip[i_prev], clip[i], input[j_prev], input[j], inter);
        output.push_back({inter.x, inter.y, inter.z});
      }
    }
  }

  // Clear duplicated points
  for (const auto &point : output) {
    int j = 0;
    bool sameFound = false;
    while (!sameFound && j < target.size()) {
      sameFound = distance_to(point, target[j]) <= 1e-6;
      j++;
    }

    if (!sameFound) {
      target.push_back(point);
    }
  }
}

bool intersectionTypeR1(const Triangle &t1, const Triangle &t2) {
  const uv_float3 &p1 = t1.a;
  const uv_float3 &q1 = t1.b;
  const uv_float3 &r1 = t1.c;
  const uv_float3 &p2 = t2.a;
  const uv_float3 &r2 = t2.c;

  if (orient2D(r2, p2, q1) >= 0) {     // I
    if (orient2D(r2, p1, q1) >= 0) {   // II.a
      if (orient2D(p1, p2, q1) >= 0) { // III.a
        return true;
      } else {
        if (orient2D(p1, p2, r1) >= 0) {   // IV.a
          if (orient2D(q1, r1, p2) >= 0) { // V
            return true;
          }
        }
      }
    }
  } else {
    if (orient2D(r2, p2, r1) >= 0) {     // II.b
      if (orient2D(q1, r1, r2) >= 0) {   // III.b
        if (orient2D(p1, p2, r1) >= 0) { // IV.b (diverges from paper)
          return true;
        }
      }
    }
  }

  return false;
}

bool intersectionTypeR2(const Triangle &t1, const Triangle &t2) {
  const uv_float3 &p1 = t1.a;
  const uv_float3 &q1 = t1.b;
  const uv_float3 &r1 = t1.c;
  const uv_float3 &p2 = t2.a;
  const uv_float3 &q2 = t2.b;
  const uv_float3 &r2 = t2.c;

  if (orient2D(r2, p2, q1) >= 0) {       // I
    if (orient2D(q2, r2, q1) >= 0) {     // II.a
      if (orient2D(p1, p2, q1) >= 0) {   // III.a
        if (orient2D(p1, q2, q1) <= 0) { // IV.a
          return true;
        }
      } else {
        if (orient2D(p1, p2, r1) >= 0) {   // IV.b
          if (orient2D(r2, p2, r1) <= 0) { // V.a
            return true;
          }
        }
      }
    } else {
      if (orient2D(p1, q2, q1) <= 0) {     // III.b
        if (orient2D(q2, r2, r1) >= 0) {   // IV.c
          if (orient2D(q1, r1, q2) >= 0) { // V.b
            return true;
          }
        }
      }
    }
  } else {
    if (orient2D(r2, p2, r1) >= 0) {     // II.b
      if (orient2D(q1, r1, r2) >= 0) {   // III.c
        if (orient2D(r1, p1, p2) >= 0) { // IV.d
          return true;
        }
      } else {
        if (orient2D(q1, r1, q2) >= 0) {   // IV.e
          if (orient2D(q2, r2, r1) >= 0) { // V.c
            return true;
          }
        }
      }
    }
  }

  return false;
}

bool coplanarIntersect(Triangle &t1, Triangle &t2,
                       std::vector<uv_float3> *target = nullptr) {
  uv_float3 normal, u, v;
  t1.getNormal(normal);
  normal = normalize(normal);
  u = normalize(t1.a - t1.b);
  v = cross(normal, u);

  // Move basis to t1.a
  u = u + t1.a;
  v = v + t1.a;
  normal = normal + t1.a;

  Matrix4 _matrix;
  _matrix.set(t1.a.x, u.x, v.x, normal.x, t1.a.y, u.y, v.y, normal.y, t1.a.z,
              u.z, v.z, normal.z, 1, 1, 1, 1);

  Matrix4 _affineMatrix;
  _affineMatrix.set(0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1);

  _matrix.invert(); // Invert the _matrix
  _matrix = _affineMatrix * _matrix;

  // Apply transformation
  apply_matrix4(t1.a, _matrix);
  apply_matrix4(t1.b, _matrix);
  apply_matrix4(t1.c, _matrix);
  apply_matrix4(t2.a, _matrix);
  apply_matrix4(t2.b, _matrix);
  apply_matrix4(t2.c, _matrix);

  makeTriCounterClockwise(t1);
  makeTriCounterClockwise(t2);

  const uv_float3 &p1 = t1.a;
  const uv_float3 &p2 = t2.a;
  const uv_float3 &q2 = t2.b;
  const uv_float3 &r2 = t2.c;

  int o_p2q2 = orient2D(p2, q2, p1);
  int o_q2r2 = orient2D(q2, r2, p1);
  int o_r2p2 = orient2D(r2, p2, p1);

  bool intersecting = false;
  if (o_p2q2 >= 0) {
    if (o_q2r2 >= 0) {
      if (o_r2p2 >= 0) {
        // + + +
        intersecting = true;
      } else {
        // + + -
        intersecting = intersectionTypeR1(t1, t2);
      }
    } else {
      if (o_r2p2 >= 0) {
        // + - +
        permuteTriRight(t2);
        intersecting = intersectionTypeR1(t1, t2);
      } else {
        // + - -
        intersecting = intersectionTypeR2(t1, t2);
      }
    }
  } else {
    if (o_q2r2 >= 0) {
      if (o_r2p2 >= 0) {
        // - + +
        permuteTriLeft(t2);
        intersecting = intersectionTypeR1(t1, t2);
      } else {
        // - + -
        permuteTriLeft(t2);
        intersecting = intersectionTypeR2(t1, t2);
      }
    } else {
      if (o_r2p2 >= 0) {
        // - - +
        permuteTriRight(t2);
        intersecting = intersectionTypeR2(t1, t2);
      } else {
        // - - -
        std::cerr << "Triangles should not be flat." << std::endl;
        return false;
      }
    }
  }

  if (intersecting && target) {
    clipTriangle(t1, t2, *target);

    _matrix.invert();
    // Apply the transform to each target point
    for (int i = 0; i < target->size(); ++i) {
      apply_matrix4(target->at(i), _matrix);
    }
  }

  return intersecting;
}

// Helper function to calculate the area of a polygon
float polygon_area(const std::vector<uv_float3> &polygon) {
  if (polygon.size() < 3)
    return 0.0f; // Not a polygon

  uv_float3 normal = {0.0f, 0.0f, 0.0f}; // Initialize normal vector

  // Calculate the cross product of edges around the polygon
  for (size_t i = 0; i < polygon.size(); ++i) {
    uv_float3 p1 = polygon[i];
    uv_float3 p2 = polygon[(i + 1) % polygon.size()];

    normal = normal + cross(p1, p2); // Accumulate the normal vector
  }

  float area =
      magnitude(normal) / 2.0f; // Area is half the magnitude of the normal
  return area;
}

bool triangle_triangle_intersection(uv_float2 p1, uv_float2 q1, uv_float2 r1,
                                    uv_float2 p2, uv_float2 q2, uv_float2 r2) {
  Triangle t1(p1, q1, r1);
  Triangle t2(p2, q2, r2);

  if (isTriDegenerated(t1) || isTriDegenerated(t2)) {
    // std::cerr << "Degenerated triangles provided, skipping." << std::endl;
    return false;
  }

  int o1a = orient3D(t2.a, t2.b, t2.c, t1.a);
  int o1b = orient3D(t2.a, t2.b, t2.c, t1.b);
  int o1c = orient3D(t2.a, t2.b, t2.c, t1.c);

  std::vector<uv_float3> intersections;
  bool intersects;

  if (o1a == o1b && o1a == o1c) // [[likely]]
  {
    intersects = o1a == 0 && coplanarIntersect(t1, t2, &intersections);
  } else // [[unlikely]]
  {
    intersects = crossIntersect(t1, t2, o1a, o1b, o1c, &intersections);
  }

  if (intersects) {
    float area = polygon_area(intersections);

    // std::cout << "Intersection area: " << area << std::endl;
    if (area < 1e-10f || std::isfinite(area) == false) {
      // std::cout<<"Invalid area: " << area << std::endl;
      return false; // Ignore intersection if the area is too small
    }
  }

  return intersects;
}
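The predicates above reduce every intersection case to sign tests of orientation determinants against a shared EPSILON. As a rough reference, here is a Python sketch of the same sign classification; it is not part of the committed files, and the EPSILON value is an assumption (the real constant lives in csrc/common.h):

import numpy as np

EPSILON = 1e-8  # assumed; the actual value is defined in csrc/common.h

def orient2d(a, b, c):
    # Sign of twice the signed area of triangle (a, b, c): +1 CCW, -1 CW, 0 degenerate
    det = a[0] * (b[1] - c[1]) + b[0] * (c[1] - a[1]) + c[0] * (a[1] - b[1])
    return -1 if det < -EPSILON else (1 if det > EPSILON else 0)

def orient3d(a, b, c, d):
    # Sign of the 4x4 homogeneous determinant: which side of plane (a, b, c) point d lies on
    m = np.array([
        [a[0], a[1], a[2], 1.0],
        [b[0], b[1], b[2], 1.0],
        [c[0], c[1], c[2], 1.0],
        [d[0], d[1], d[2], 1.0],
    ])
    det = np.linalg.det(m)
    return -1 if det < -EPSILON else (1 if det > EPSILON else 0)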
uv_unwrapper/uv_unwrapper/csrc/intersect.h
ADDED
@@ -0,0 +1,10 @@
#pragma once

#include "common.h"
#include <vector>

bool triangle_aabb_intersection(const uv_float2 &aabb_min,
                                const uv_float2 &aabb_max, const uv_float2 &v0,
                                const uv_float2 &v1, const uv_float2 &v2);
bool triangle_triangle_intersection(uv_float2 p1, uv_float2 q1, uv_float2 r1,
                                    uv_float2 p2, uv_float2 q2, uv_float2 r2);
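For intuition about what triangle_aabb_intersection answers ("does the 2D triangle overlap the axis-aligned box?"), a deliberately simplified Python sketch is shown below; it is only illustrative, does not belong to the repository, and replaces the full edge and corner tests of intersect.cpp with a conservative fallback:

def tri_aabb_overlap_conservative(aabb_min, aabb_max, v0, v1, v2):
    xs = [v0[0], v1[0], v2[0]]
    ys = [v0[1], v1[1], v2[1]]
    # Bounding-box rejection: if the triangle's bounds miss the box, no overlap is possible
    if max(xs) < aabb_min[0] or min(xs) > aabb_max[0]:
        return False
    if max(ys) < aabb_min[1] or min(ys) > aabb_max[1]:
        return False
    # Any triangle vertex inside the box is a definite overlap
    for x, y in zip(xs, ys):
        if aabb_min[0] <= x <= aabb_max[0] and aabb_min[1] <= y <= aabb_max[1]:
            return True
    # Remaining cases (edge crossings, box corners inside the triangle) need the
    # exact tests implemented in intersect.cpp; answer conservatively here
    return True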
uv_unwrapper/uv_unwrapper/csrc/unwrapper.cpp
ADDED
@@ -0,0 +1,271 @@
#include "bvh.h"
#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <chrono>
#include <cmath>
#include <cstring>
#include <omp.h>
#include <set>
#include <torch/extension.h>
#include <vector>

// #define TIMING

#if defined(_MSC_VER)
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif

namespace UVUnwrapper {
void create_bvhs(BVH *bvhs, Triangle *triangles,
                 std::vector<std::set<int>> &triangle_per_face, int num_faces,
                 int start, int end) {
#pragma omp parallel for
  for (int i = start; i < end; i++) {
    int num_triangles = triangle_per_face[i].size();
    Triangle *triangles_per_face = new Triangle[num_triangles];
    int *indices = new int[num_triangles];
    int j = 0;
    for (int idx : triangle_per_face[i]) {
      triangles_per_face[j] = triangles[idx];
      indices[j++] = idx;
    }
    // Each thread writes to its own memory space
    // First check if the number of triangles is 0
    if (num_triangles == 0) {
      bvhs[i - start] = std::move(BVH()); // Default constructor
    } else {
      bvhs[i - start] = std::move(
          BVH(triangles_per_face, indices,
              num_triangles)); // BVH now handles memory of triangles_per_face
    }
    delete[] triangles_per_face;
  }
}

void perform_intersection_check(BVH *bvhs, int num_bvhs, Triangle *triangles,
                                uv_float3 *vertex_tri_centroids,
                                int64_t *assign_indices_ptr,
                                ssize_t num_indices, int offset,
                                std::vector<std::set<int>> &triangle_per_face) {
  std::vector<std::pair<int, int>>
      unique_intersections; // Store unique intersections as pairs of triangle
                            // indices

  // Step 1: Detect intersections in parallel
#pragma omp parallel for
  for (int i = 0; i < num_indices; i++) {
    if (assign_indices_ptr[i] < offset) {
      continue;
    }

    Triangle cur_tri = triangles[i];
    auto &cur_bvh = bvhs[assign_indices_ptr[i] - offset];

    if (cur_bvh.bvhNode == nullptr) {
      continue;
    }

    std::vector<int> intersections = cur_bvh.Intersect(cur_tri);

    if (!intersections.empty()) {

#pragma omp critical
      {
        for (int intersect : intersections) {
          if (i != intersect) {
            // Ensure we only store unique pairs (A, B) where A < B to avoid
            // duplication
            if (i < intersect) {
              unique_intersections.push_back(std::make_pair(i, intersect));
            } else {
              unique_intersections.push_back(std::make_pair(intersect, i));
            }
          }
        }
      }
    }
  }

  // Step 2: Process unique intersections
  for (int idx = 0; idx < unique_intersections.size(); idx++) {
    int first = unique_intersections[idx].first;
    int second = unique_intersections[idx].second;

    int i_idx = assign_indices_ptr[first];

    int norm_idx = i_idx % 6;
    int axis = (norm_idx < 2) ? 0 : (norm_idx < 4) ? 1 : 2;
    bool use_max = (i_idx % 2) == 1;

    float pos_a = vertex_tri_centroids[first][axis];
    float pos_b = vertex_tri_centroids[second][axis];
    // Sort the intersections based on vertex_tri_centroids along the specified
    // axis
    if (use_max) {
      if (pos_a < pos_b) {
        std::swap(first, second);
      }
    } else {
      if (pos_a > pos_b) {
        std::swap(first, second);
      }
    }

    // Update the unique intersections
    unique_intersections[idx].first = first;
    unique_intersections[idx].second = second;
  }

  // Now only get the second intersections from the pair and put them in a set
  // The second intersection should always be the occluded triangle
  std::set<int> second_intersections;
  for (int idx = 0; idx < (int)unique_intersections.size(); idx++) {
    int second = unique_intersections[idx].second;
    second_intersections.insert(second);
  }

  for (int int_idx : second_intersections) {
    // Move the second (occluded) triangle by 6
    int intersect_idx = assign_indices_ptr[int_idx];
    int new_index = intersect_idx + 6;
    new_index = std::clamp(new_index, 0, 12);

    assign_indices_ptr[int_idx] = new_index;
    triangle_per_face[intersect_idx].erase(int_idx);
    triangle_per_face[new_index].insert(int_idx);
  }
}

torch::Tensor assign_faces_uv_to_atlas_index(torch::Tensor vertices,
                                             torch::Tensor indices,
                                             torch::Tensor face_uv,
                                             torch::Tensor face_index) {
  // Get the number of faces
  int num_faces = indices.size(0);
  torch::Tensor assign_indices =
      torch::empty(
          {
              num_faces,
          },
          torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU))
          .contiguous();

  auto vert_accessor = vertices.accessor<float, 2>();
  auto indices_accessor = indices.accessor<int64_t, 2>();
  auto face_uv_accessor = face_uv.accessor<float, 2>();

  const int64_t *face_index_ptr = face_index.contiguous().data_ptr<int64_t>();
  int64_t *assign_indices_ptr = assign_indices.data_ptr<int64_t>();
  // copy face_index to assign_indices
  memcpy(assign_indices_ptr, face_index_ptr, num_faces * sizeof(int64_t));

#ifdef TIMING
  auto start = std::chrono::high_resolution_clock::now();
#endif
  uv_float3 *vertex_tri_centroids = new uv_float3[num_faces];
  Triangle *triangles = new Triangle[num_faces];

  // Use std::set to store triangles for each face
  std::vector<std::set<int>> triangle_per_face;
  triangle_per_face.resize(13);

#pragma omp parallel for
  for (int i = 0; i < num_faces; i++) {
    int face_idx = i * 3;
    triangles[i].v0 = {face_uv_accessor[face_idx + 0][0],
                       face_uv_accessor[face_idx + 0][1]};
    triangles[i].v1 = {face_uv_accessor[face_idx + 1][0],
                       face_uv_accessor[face_idx + 1][1]};
    triangles[i].v2 = {face_uv_accessor[face_idx + 2][0],
                       face_uv_accessor[face_idx + 2][1]};
    triangles[i].centroid =
        triangle_centroid(triangles[i].v0, triangles[i].v1, triangles[i].v2);

    uv_float3 v0 = {vert_accessor[indices_accessor[i][0]][0],
                    vert_accessor[indices_accessor[i][0]][1],
                    vert_accessor[indices_accessor[i][0]][2]};
    uv_float3 v1 = {vert_accessor[indices_accessor[i][1]][0],
                    vert_accessor[indices_accessor[i][1]][1],
                    vert_accessor[indices_accessor[i][1]][2]};
    uv_float3 v2 = {vert_accessor[indices_accessor[i][2]][0],
                    vert_accessor[indices_accessor[i][2]][1],
                    vert_accessor[indices_accessor[i][2]][2]};
    vertex_tri_centroids[i] = triangle_centroid(v0, v1, v2);

    // Assign the triangle to the face index
#pragma omp critical
    { triangle_per_face[face_index_ptr[i]].insert(i); }
  }

#ifdef TIMING
  auto start_bvh = std::chrono::high_resolution_clock::now();
#endif

  BVH *bvhs = new BVH[6];
  create_bvhs(bvhs, triangles, triangle_per_face, num_faces, 0, 6);

#ifdef TIMING
  auto end_bvh = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> elapsed_seconds = end_bvh - start_bvh;
  std::cout << "BVH build time: " << elapsed_seconds.count() << "s\n";

  auto start_intersection_1 = std::chrono::high_resolution_clock::now();
#endif

  perform_intersection_check(bvhs, 6, triangles, vertex_tri_centroids,
                             assign_indices_ptr, num_faces, 0,
                             triangle_per_face);

#ifdef TIMING
  auto end_intersection_1 = std::chrono::high_resolution_clock::now();
  elapsed_seconds = end_intersection_1 - start_intersection_1;
  std::cout << "Intersection 1 time: " << elapsed_seconds.count() << "s\n";
#endif
  // Create 6 new bvhs and delete the old ones
  BVH *new_bvhs = new BVH[6];
  create_bvhs(new_bvhs, triangles, triangle_per_face, num_faces, 6, 12);

#ifdef TIMING
  auto end_bvh2 = std::chrono::high_resolution_clock::now();
  elapsed_seconds = end_bvh2 - end_intersection_1;
  std::cout << "BVH 2 build time: " << elapsed_seconds.count() << "s\n";
  auto start_intersection_2 = std::chrono::high_resolution_clock::now();
#endif

  perform_intersection_check(new_bvhs, 6, triangles, vertex_tri_centroids,
                             assign_indices_ptr, num_faces, 6,
                             triangle_per_face);

#ifdef TIMING
  auto end_intersection_2 = std::chrono::high_resolution_clock::now();
  elapsed_seconds = end_intersection_2 - start_intersection_2;
  std::cout << "Intersection 2 time: " << elapsed_seconds.count() << "s\n";
  elapsed_seconds = end_intersection_2 - start;
  std::cout << "Total time: " << elapsed_seconds.count() << "s\n";
#endif

  // Cleanup
  delete[] vertex_tri_centroids;
  delete[] triangles;
  delete[] bvhs;
  delete[] new_bvhs;

  return assign_indices;
}

// Registers _C as a Python extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}

// Defines the operators
TORCH_LIBRARY(UVUnwrapper, m) {
  m.def("assign_faces_uv_to_atlas_index(Tensor vertices, Tensor indices, "
        "Tensor face_uv, Tensor face_index) -> Tensor");
}

// Registers CPP implementations
TORCH_LIBRARY_IMPL(UVUnwrapper, CPU, m) {
  m.impl("assign_faces_uv_to_atlas_index", &assign_faces_uv_to_atlas_index);
}

} // namespace UVUnwrapper
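Because the TORCH_LIBRARY block above registers the operator schema, the CPU implementation becomes callable from Python as torch.ops.UVUnwrapper.assign_faces_uv_to_atlas_index, which is exactly how unwrap.py below invokes it. A minimal usage sketch, assuming the compiled extension has already been loaded (for example by importing the uv_unwrapper package) and using a toy single-triangle mesh:

import torch
import uv_unwrapper  # assumed to load the compiled extension on import

# Toy mesh: one triangle already assigned to cube face 0
vertices = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # (Nv, 3) float32
indices = torch.tensor([[0, 1, 2]], dtype=torch.int64)                        # (Nf, 3)
face_uv = torch.tensor([[0.1, 0.1], [0.9, 0.1], [0.1, 0.9]])                  # (Nf * 3, 2), one UV per corner
face_index = torch.tensor([0], dtype=torch.int64)                             # (Nf,), initial cube-face assignment

atlas_index = torch.ops.UVUnwrapper.assign_faces_uv_to_atlas_index(
    vertices, indices, face_uv, face_index
)
print(atlas_index)  # stays 0 unless the face is occluded by an overlapping face in UV space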
uv_unwrapper/uv_unwrapper/unwrap.py
ADDED
|
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Unwrapper(nn.Module):
|
| 11 |
+
def __init__(self):
|
| 12 |
+
super().__init__()
|
| 13 |
+
|
| 14 |
+
def _box_assign_vertex_to_cube_face(
|
| 15 |
+
self,
|
| 16 |
+
vertex_positions: Tensor,
|
| 17 |
+
vertex_normals: Tensor,
|
| 18 |
+
triangle_idxs: Tensor,
|
| 19 |
+
bbox: Tensor,
|
| 20 |
+
) -> Tuple[Tensor, Tensor]:
|
| 21 |
+
"""
|
| 22 |
+
Assigns each vertex to a cube face based on the face normal
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
vertex_positions (Tensor, Nv 3, float): Vertex positions
|
| 26 |
+
vertex_normals (Tensor, Nv 3, float): Vertex normals
|
| 27 |
+
triangle_idxs (Tensor, Nf 3, int): Triangle indices
|
| 28 |
+
bbox (Tensor, 2 3, float): Bounding box of the mesh
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
Tensor, Nf 3 2, float: UV coordinates
|
| 32 |
+
Tensor, Nf, int: Cube face indices
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
# Test to not have a scaled model to fit the space better
|
| 36 |
+
# bbox_min = bbox[:1].mean(-1, keepdim=True)
|
| 37 |
+
# bbox_max = bbox[1:].mean(-1, keepdim=True)
|
| 38 |
+
# v_pos_normalized = (vertex_positions - bbox_min) / (bbox_max - bbox_min)
|
| 39 |
+
|
| 40 |
+
# Create a [0, 1] normalized vertex position
|
| 41 |
+
v_pos_normalized = (vertex_positions - bbox[:1]) / (bbox[1:] - bbox[:1])
|
| 42 |
+
# And to [-1, 1]
|
| 43 |
+
v_pos_normalized = 2.0 * v_pos_normalized - 1.0
|
| 44 |
+
|
| 45 |
+
# Get all vertex positions for each triangle
|
| 46 |
+
# Now how do we define to which face the triangle belongs? Mean face pos? Max vertex pos?
|
| 47 |
+
v0 = v_pos_normalized[triangle_idxs[:, 0]]
|
| 48 |
+
v1 = v_pos_normalized[triangle_idxs[:, 1]]
|
| 49 |
+
v2 = v_pos_normalized[triangle_idxs[:, 2]]
|
| 50 |
+
tri_stack = torch.stack([v0, v1, v2], dim=1)
|
| 51 |
+
|
| 52 |
+
vn0 = vertex_normals[triangle_idxs[:, 0]]
|
| 53 |
+
vn1 = vertex_normals[triangle_idxs[:, 1]]
|
| 54 |
+
vn2 = vertex_normals[triangle_idxs[:, 2]]
|
| 55 |
+
tri_stack_nrm = torch.stack([vn0, vn1, vn2], dim=1)
|
| 56 |
+
|
| 57 |
+
# Just average the normals per face
|
| 58 |
+
face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1)
|
| 59 |
+
|
| 60 |
+
# Now decide based on the face normal in which box map we project
|
| 61 |
+
# abs_x, abs_y, abs_z = tri_stack_nrm.abs().unbind(-1)
|
| 62 |
+
abs_x, abs_y, abs_z = tri_stack.abs().unbind(-1)
|
| 63 |
+
|
| 64 |
+
axis = torch.tensor(
|
| 65 |
+
[
|
| 66 |
+
[1, 0, 0], # 0
|
| 67 |
+
[-1, 0, 0], # 1
|
| 68 |
+
[0, 1, 0], # 2
|
| 69 |
+
[0, -1, 0], # 3
|
| 70 |
+
[0, 0, 1], # 4
|
| 71 |
+
[0, 0, -1], # 5
|
| 72 |
+
],
|
| 73 |
+
device=face_normal.device,
|
| 74 |
+
dtype=face_normal.dtype,
|
| 75 |
+
)
|
| 76 |
+
face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1)
|
| 77 |
+
index = face_normal_axis.argmax(-1)
|
| 78 |
+
|
| 79 |
+
max_axis, uc, vc = (
|
| 80 |
+
torch.ones_like(abs_x),
|
| 81 |
+
torch.zeros_like(tri_stack[..., :1]),
|
| 82 |
+
torch.zeros_like(tri_stack[..., :1]),
|
| 83 |
+
)
|
| 84 |
+
mask_pos_x = index == 0
|
| 85 |
+
max_axis[mask_pos_x] = abs_x[mask_pos_x]
|
| 86 |
+
uc[mask_pos_x] = tri_stack[mask_pos_x][..., 1:2]
|
| 87 |
+
vc[mask_pos_x] = -tri_stack[mask_pos_x][..., -1:]
|
| 88 |
+
|
| 89 |
+
mask_neg_x = index == 1
|
| 90 |
+
max_axis[mask_neg_x] = abs_x[mask_neg_x]
|
| 91 |
+
uc[mask_neg_x] = tri_stack[mask_neg_x][..., 1:2]
|
| 92 |
+
vc[mask_neg_x] = -tri_stack[mask_neg_x][..., -1:]
|
| 93 |
+
|
| 94 |
+
mask_pos_y = index == 2
|
| 95 |
+
max_axis[mask_pos_y] = abs_y[mask_pos_y]
|
| 96 |
+
uc[mask_pos_y] = tri_stack[mask_pos_y][..., 0:1]
|
| 97 |
+
vc[mask_pos_y] = -tri_stack[mask_pos_y][..., -1:]
|
| 98 |
+
|
| 99 |
+
mask_neg_y = index == 3
|
| 100 |
+
max_axis[mask_neg_y] = abs_y[mask_neg_y]
|
| 101 |
+
uc[mask_neg_y] = tri_stack[mask_neg_y][..., 0:1]
|
| 102 |
+
vc[mask_neg_y] = -tri_stack[mask_neg_y][..., -1:]
|
| 103 |
+
|
| 104 |
+
mask_pos_z = index == 4
|
| 105 |
+
max_axis[mask_pos_z] = abs_z[mask_pos_z]
|
| 106 |
+
uc[mask_pos_z] = tri_stack[mask_pos_z][..., 0:1]
|
| 107 |
+
vc[mask_pos_z] = tri_stack[mask_pos_z][..., 1:2]
|
| 108 |
+
|
| 109 |
+
mask_neg_z = index == 5
|
| 110 |
+
max_axis[mask_neg_z] = abs_z[mask_neg_z]
|
| 111 |
+
uc[mask_neg_z] = tri_stack[mask_neg_z][..., 0:1]
|
| 112 |
+
vc[mask_neg_z] = -tri_stack[mask_neg_z][..., 1:2]
|
| 113 |
+
|
| 114 |
+
# UC from [-1, 1] to [0, 1]
|
| 115 |
+
max_dim_div = max_axis.max(dim=0, keepdim=True).values
|
| 116 |
+
uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
|
| 117 |
+
vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
|
| 118 |
+
|
| 119 |
+
uv = torch.stack([uc, vc], dim=-1)
|
| 120 |
+
|
| 121 |
+
return uv, index
|
| 122 |
+
|
| 123 |
+
def _assign_faces_uv_to_atlas_index(
|
| 124 |
+
self,
|
| 125 |
+
vertex_positions: Tensor,
|
| 126 |
+
triangle_idxs: Tensor,
|
| 127 |
+
face_uv: Tensor,
|
| 128 |
+
face_index: Tensor,
|
| 129 |
+
) -> Tensor: # noqa: F821
|
| 130 |
+
"""
|
| 131 |
+
Assigns the face UV to the atlas index
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
|
| 135 |
+
triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
|
| 136 |
+
face_uv (Float[Tensor, "Nf 3 2"]): Face UV coordinates
|
| 137 |
+
face_index (Integer[Tensor, "Nf"]): Face indices
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
Integer[Tensor, "Nf"]: Atlas index
|
| 141 |
+
"""
|
| 142 |
+
return torch.ops.UVUnwrapper.assign_faces_uv_to_atlas_index(
|
| 143 |
+
vertex_positions.cpu(),
|
| 144 |
+
triangle_idxs.cpu(),
|
| 145 |
+
face_uv.view(-1, 2).cpu(),
|
| 146 |
+
face_index.cpu(),
|
| 147 |
+
).to(vertex_positions.device)
|
| 148 |
+
|
| 149 |
+
def _find_slice_offset_and_scale(
|
| 150 |
+
self, index: Tensor
|
| 151 |
+
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: # noqa: F821
|
| 152 |
+
"""
|
| 153 |
+
Find the slice offset and scale
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
index (Integer[Tensor, "Nf"]): Atlas index
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
Float[Tensor, "Nf"]: Offset x
|
| 160 |
+
Float[Tensor, "Nf"]: Offset y
|
| 161 |
+
Float[Tensor, "Nf"]: Division x
|
| 162 |
+
Float[Tensor, "Nf"]: Division y
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
# 6 due to the 6 cube faces
|
| 166 |
+
off = 1 / 3
|
| 167 |
+
dupl_off = 1 / 6
|
| 168 |
+
|
| 169 |
+
# Here, we need to decide how to pack the textures in the case of overlap
|
| 170 |
+
def x_offset_calc(x, i):
|
| 171 |
+
offset_calc = i // 6
|
| 172 |
+
# Initial coordinates - just 3x2 grid
|
| 173 |
+
if offset_calc == 0:
|
| 174 |
+
return off * x
|
| 175 |
+
else:
|
| 176 |
+
# Smaller 3x2 grid plus eventual shift to right for
|
| 177 |
+
# second overlap
|
| 178 |
+
return dupl_off * x + min(offset_calc - 1, 1) * 0.5
|
| 179 |
+
|
| 180 |
+
def y_offset_calc(x, i):
|
| 181 |
+
offset_calc = i // 6
|
| 182 |
+
# Initial coordinates - just a 3x2 grid
|
| 183 |
+
if offset_calc == 0:
|
| 184 |
+
return off * x
|
| 185 |
+
else:
|
| 186 |
+
# Smaller coordinates in the lowest row
|
| 187 |
+
return dupl_off * x + off * 2
|
| 188 |
+
|
| 189 |
+
offset_x = torch.zeros_like(index, dtype=torch.float32)
|
| 190 |
+
offset_y = torch.zeros_like(index, dtype=torch.float32)
|
| 191 |
+
offset_x_vals = [0, 1, 2, 0, 1, 2]
|
| 192 |
+
offset_y_vals = [0, 0, 0, 1, 1, 1]
|
| 193 |
+
for i in range(index.max().item() + 1):
|
| 194 |
+
mask = index == i
|
| 195 |
+
if not mask.any():
|
| 196 |
+
continue
|
| 197 |
+
offset_x[mask] = x_offset_calc(offset_x_vals[i % 6], i)
|
| 198 |
+
offset_y[mask] = y_offset_calc(offset_y_vals[i % 6], i)
|
| 199 |
+
|
| 200 |
+
div_x = torch.full_like(index, 6 // 2, dtype=torch.float32)
|
| 201 |
+
# All overlap elements are saved in half scale
|
| 202 |
+
div_x[index >= 6] = 6
|
| 203 |
+
div_y = div_x.clone() # Same for y
|
| 204 |
+
# Except for the random overlaps
|
| 205 |
+
div_x[index >= 12] = 2
|
| 206 |
+
# But the random overlaps are saved in a large block in the lower thirds
|
| 207 |
+
div_y[index >= 12] = 3
|
| 208 |
+
|
| 209 |
+
return offset_x, offset_y, div_x, div_y
|
| 210 |
+
|
| 211 |
+
def _calculate_tangents(
|
| 212 |
+
self,
|
| 213 |
+
vertex_positions: Tensor,
|
| 214 |
+
vertex_normals: Tensor,
|
| 215 |
+
triangle_idxs: Tensor,
|
| 216 |
+
face_uv: Tensor,
|
| 217 |
+
) -> Tensor:
|
| 218 |
+
"""
|
| 219 |
+
Calculate the tangents for each triangle
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
|
| 223 |
+
vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals
|
| 224 |
+
triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
|
| 225 |
+
face_uv (Float[Tensor, "Nf 3 2"]): Face UV coordinates
|
| 226 |
+
|
| 227 |
+
Returns:
|
| 228 |
+
Float[Tensor, "Nf 3 4"]: Tangents
|
| 229 |
+
"""
|
| 230 |
+
vn_idx = [None] * 3
|
| 231 |
+
pos = [None] * 3
|
| 232 |
+
tex = face_uv.unbind(1)
|
| 233 |
+
for i in range(0, 3):
|
| 234 |
+
pos[i] = vertex_positions[triangle_idxs[:, i]]
|
| 235 |
+
# t_nrm_idx is always the same as t_pos_idx
|
| 236 |
+
vn_idx[i] = triangle_idxs[:, i]
|
| 237 |
+
|
| 238 |
+
if torch.backends.mps.is_available():
|
| 239 |
+
tangents = torch.zeros_like(vertex_normals).contiguous()
|
| 240 |
+
tansum = torch.zeros_like(vertex_normals).contiguous()
|
| 241 |
+
else:
|
| 242 |
+
tangents = torch.zeros_like(vertex_normals)
|
| 243 |
+
tansum = torch.zeros_like(vertex_normals)
|
| 244 |
+
|
| 245 |
+
# Compute tangent space for each triangle
|
| 246 |
+
duv1 = tex[1] - tex[0]
|
| 247 |
+
duv2 = tex[2] - tex[0]
|
| 248 |
+
dpos1 = pos[1] - pos[0]
|
| 249 |
+
dpos2 = pos[2] - pos[0]
|
| 250 |
+
|
| 251 |
+
tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
|
| 252 |
+
|
| 253 |
+
denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
|
| 254 |
+
|
| 255 |
+
# Avoid division by zero for degenerated texture coordinates
|
| 256 |
+
denom_safe = denom.clip(1e-6)
|
| 257 |
+
tang = tng_nom / denom_safe
|
| 258 |
+
|
| 259 |
+
# Update all 3 vertices
|
| 260 |
+
for i in range(0, 3):
|
| 261 |
+
idx = vn_idx[i][:, None].repeat(1, 3)
|
| 262 |
+
tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
|
| 263 |
+
tansum.scatter_add_(
|
| 264 |
+
0, idx, torch.ones_like(tang)
|
| 265 |
+
) # tansum[n_i] = tansum[n_i] + 1
|
| 266 |
+
# Also normalize it. Here we do not normalize the individual triangles first so larger area
|
| 267 |
+
# triangles influence the tangent space more
|
| 268 |
+
tangents = tangents / tansum
|
| 269 |
+
|
| 270 |
+
# Normalize and make sure tangent is perpendicular to normal
|
| 271 |
+
tangents = F.normalize(tangents, dim=1)
|
| 272 |
+
tangents = F.normalize(
|
| 273 |
+
tangents
|
| 274 |
+
- (tangents * vertex_normals).sum(-1, keepdim=True) * vertex_normals
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
return tangents
|
| 278 |
+
|
| 279 |
+
def _rotate_uv_slices_consistent_space(
|
| 280 |
+
self,
|
| 281 |
+
vertex_positions: Tensor,
|
| 282 |
+
vertex_normals: Tensor,
|
| 283 |
+
triangle_idxs: Tensor,
|
| 284 |
+
uv: Tensor,
|
| 285 |
+
index: Tensor,
|
| 286 |
+
) -> Tensor:
|
| 287 |
+
"""
|
| 288 |
+
Rotate the UV slices so they are in a consistent space
|
| 289 |
+
|
| 290 |
+
Args:
|
| 291 |
+
vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
|
| 292 |
+
vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals
|
| 293 |
+
triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
|
| 294 |
+
uv (Float[Tensor, "Nf 3 2"]): UV coordinates
|
| 295 |
+
index (Integer[Tensor, "Nf"]): Atlas index
|
| 296 |
+
|
| 297 |
+
Returns:
|
| 298 |
+
Float[Tensor, "Nf 3 2"]: Rotated UV coordinates
|
| 299 |
+
"""
|
| 300 |
+
|
| 301 |
+
tangents = self._calculate_tangents(
|
| 302 |
+
vertex_positions, vertex_normals, triangle_idxs, uv
|
| 303 |
+
)
|
| 304 |
+
pos_stack = torch.stack(
|
| 305 |
+
[
|
| 306 |
+
-vertex_positions[..., 1],
|
| 307 |
+
vertex_positions[..., 0],
|
| 308 |
+
torch.zeros_like(vertex_positions[..., 0]),
|
| 309 |
+
],
|
| 310 |
+
dim=-1,
|
| 311 |
+
)
|
| 312 |
+
expected_tangents = F.normalize(
|
| 313 |
+
torch.linalg.cross(
|
| 314 |
+
vertex_normals,
|
| 315 |
+
torch.linalg.cross(pos_stack, vertex_normals, dim=-1),
|
| 316 |
+
dim=-1,
|
| 317 |
+
),
|
| 318 |
+
-1,
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
actual_tangents = tangents[triangle_idxs]
|
| 322 |
+
expected_tangents = expected_tangents[triangle_idxs]
|
| 323 |
+
|
| 324 |
+
def rotation_matrix_2d(theta):
|
| 325 |
+
c, s = torch.cos(theta), torch.sin(theta)
|
| 326 |
+
return torch.tensor([[c, -s], [s, c]])
|
| 327 |
+
|
| 328 |
+
# Now find the rotation
|
| 329 |
+
index_mod = index % 6 # Shouldn't happen. Just for safety
|
| 330 |
+
for i in range(6):
|
| 331 |
+
mask = index_mod == i
|
| 332 |
+
if not mask.any():
|
| 333 |
+
continue
|
| 334 |
+
|
| 335 |
+
actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
|
| 336 |
+
expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))
|
| 337 |
+
|
| 338 |
+
dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
|
| 339 |
+
cross_product = (
|
| 340 |
+
actual_mean_tangent[0] * expected_mean_tangent[1]
|
| 341 |
+
- actual_mean_tangent[1] * expected_mean_tangent[0]
|
| 342 |
+
)
|
| 343 |
+
angle = torch.atan2(cross_product, dot_product)
|
| 344 |
+
|
| 345 |
+
rot_matrix = rotation_matrix_2d(angle).to(mask.device)
|
| 346 |
+
# Center the uv coordinate to be in the range of -1 to 1 and 0 centered
|
| 347 |
+
uv_cur = uv[mask] * 2 - 1 # Center it first
|
| 348 |
+
# Rotate it
|
| 349 |
+
uv[mask] = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur)
|
| 350 |
+
|
| 351 |
+
# Rescale uv[mask] to be within the 0-1 range
|
| 352 |
+
uv[mask] = (uv[mask] - uv[mask].min()) / (uv[mask].max() - uv[mask].min())
|
| 353 |
+
|
| 354 |
+
return uv
|
| 355 |
+
|
| 356 |
+
def _handle_slice_uvs(
|
| 357 |
+
self,
|
| 358 |
+
uv: Tensor,
|
| 359 |
+
index: Tensor, # noqa: F821
|
| 360 |
+
island_padding: float,
|
| 361 |
+
max_index: int = 6 * 2,
|
| 362 |
+
) -> Tensor: # noqa: F821
|
| 363 |
+
"""
|
| 364 |
+
Handle the slice UVs
|
| 365 |
+
|
| 366 |
+
Args:
|
| 367 |
+
uv (Float[Tensor, "Nf 3 2"]): UV coordinates
|
| 368 |
+
index (Integer[Tensor, "Nf"]): Atlas index
|
| 369 |
+
island_padding (float): Island padding
|
| 370 |
+
max_index (int): Maximum index
|
| 371 |
+
|
| 372 |
+
Returns:
|
| 373 |
+
Float[Tensor, "Nf 3 2"]: Updated UV coordinates
|
| 374 |
+
|
| 375 |
+
"""
|
| 376 |
+
uc, vc = uv.unbind(-1)
|
| 377 |
+
|
| 378 |
+
# Get the second slice (The first overlap)
|
| 379 |
+
index_filter = [index == i for i in range(6, max_index)]
|
| 380 |
+
|
| 381 |
+
# Normalize them to always fully fill the atlas patch
|
| 382 |
+
for i, fi in enumerate(index_filter):
|
| 383 |
+
if fi.sum() > 0:
|
| 384 |
+
# Scale the slice but only up to a factor of 2
|
| 385 |
+
# This keeps the texture resolution with the first slice in line (Half space in UV)
|
| 386 |
+
uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip(
|
| 387 |
+
0.5
|
| 388 |
+
)
|
| 389 |
+
vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip(
|
| 390 |
+
0.5
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
|
| 394 |
+
vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
|
| 395 |
+
|
| 396 |
+
return torch.stack([uc_padded, vc_padded], dim=-1)
|
| 397 |
+
|
| 398 |
+
def _handle_remaining_uvs(
|
| 399 |
+
self,
|
| 400 |
+
uv: Tensor,
|
| 401 |
+
index: Tensor, # noqa: F821
|
| 402 |
+
island_padding: float,
|
| 403 |
+
) -> Tensor:
|
| 404 |
+
"""
|
| 405 |
+
Handle the remaining UVs (The ones that are not slices)
|
| 406 |
+
|
| 407 |
+
Args:
|
| 408 |
+
uv (Float[Tensor, "Nf 3 2"]): UV coordinates
|
| 409 |
+
index (Integer[Tensor, "Nf"]): Atlas index
|
| 410 |
+
island_padding (float): Island padding
|
| 411 |
+
|
| 412 |
+
Returns:
|
| 413 |
+
Float[Tensor, "Nf 3 2"]: Updated UV coordinates
|
| 414 |
+
"""
|
| 415 |
+
uc, vc = uv.unbind(-1)
|
| 416 |
+
# Get all remaining elements
|
| 417 |
+
remaining_filter = index >= 6 * 2
|
| 418 |
+
squares_left = remaining_filter.sum()
|
| 419 |
+
|
| 420 |
+
if squares_left == 0:
|
| 421 |
+
return uv
|
| 422 |
+
|
| 423 |
+
uc = uc[remaining_filter]
|
| 424 |
+
vc = vc[remaining_filter]
|
| 425 |
+
|
| 426 |
+
# Or remaining triangles are distributed in a rectangle
|
| 427 |
+
# The rectangle takes 0.5 of the entire uv space in width and 1/3 in height
|
| 428 |
+
ratio = 0.5 * (1 / 3) # 1.5
|
| 429 |
+
# sqrt(744/(0.5*(1/3)))
|
| 430 |
+
|
| 431 |
+
mult = math.sqrt(squares_left / ratio)
|
| 432 |
+
num_square_width = int(math.ceil(0.5 * mult))
|
| 433 |
+
num_square_height = int(math.ceil(squares_left / num_square_width))
|
| 434 |
+
|
| 435 |
+
width = 1 / num_square_width
|
| 436 |
+
height = 1 / num_square_height
|
| 437 |
+
|
| 438 |
+
# The idea is again to keep the texture resolution consistent with the first slice
|
| 439 |
+
# This only occupys half the region in the texture chart but the scaling on the squares
|
| 440 |
+
# assumes full coverage.
|
| 441 |
+
clip_val = min(width, height) * 1.5
|
| 442 |
+
# Now normalize the UVs with taking into account the maximum scaling
|
| 443 |
+
uc = (uc - uc.min(dim=1, keepdim=True).values) / (
|
| 444 |
+
uc.amax(dim=1, keepdim=True) - uc.amin(dim=1, keepdim=True)
|
| 445 |
+
).clip(clip_val)
|
| 446 |
+
vc = (vc - vc.min(dim=1, keepdim=True).values) / (
|
| 447 |
+
vc.amax(dim=1, keepdim=True) - vc.amin(dim=1, keepdim=True)
|
| 448 |
+
).clip(clip_val)
|
| 449 |
+
# Add a small padding
|
| 450 |
+
uc = (
|
| 451 |
+
uc * (1 - island_padding * num_square_width * 0.5)
|
| 452 |
+
+ island_padding * num_square_width * 0.25
|
| 453 |
+
).clip(0, 1)
|
| 454 |
+
vc = (
|
| 455 |
+
vc * (1 - island_padding * num_square_height * 0.5)
|
| 456 |
+
+ island_padding * num_square_height * 0.25
|
| 457 |
+
).clip(0, 1)
|
| 458 |
+
|
| 459 |
+
uc = uc * width
|
| 460 |
+
vc = vc * height
|
| 461 |
+
|
| 462 |
+
# And calculate offsets for each element
|
| 463 |
+
idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32)
|
| 464 |
+
x_idx = idx % num_square_width
|
| 465 |
+
y_idx = idx // num_square_width
|
| 466 |
+
# And move each triangle to its own spot
|
| 467 |
+
uc = uc + x_idx[:, None] * width
|
| 468 |
+
vc = vc + y_idx[:, None] * height
|
| 469 |
+
|
| 470 |
+
uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
|
| 471 |
+
vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
|
| 472 |
+
|
| 473 |
+
uv[remaining_filter] = torch.stack([uc, vc], dim=-1)
|
| 474 |
+
|
| 475 |
+
return uv
|
| 476 |
+
|
| 477 |
+
def _distribute_individual_uvs_in_atlas(
|
| 478 |
+
self,
|
| 479 |
+
face_uv: Tensor,
|
| 480 |
+
assigned_faces: Tensor,
|
| 481 |
+
offset_x: Tensor,
|
| 482 |
+
offset_y: Tensor,
|
| 483 |
+
div_x: Tensor,
|
| 484 |
+
div_y: Tensor,
|
| 485 |
+
island_padding: float,
|
| 486 |
+
) -> Tensor:
|
| 487 |
+
"""
|
| 488 |
+
Distribute the individual UVs in the atlas
|
| 489 |
+
|
| 490 |
+
Args:
|
| 491 |
+
face_uv (Float[Tensor, "Nf 3 2"]): Face UV coordinates
|
| 492 |
+
assigned_faces (Integer[Tensor, "Nf"]): Assigned faces
|
| 493 |
+
offset_x (Float[Tensor, "Nf"]): Offset x
|
| 494 |
+
offset_y (Float[Tensor, "Nf"]): Offset y
|
| 495 |
+
div_x (Float[Tensor, "Nf"]): Division x
|
| 496 |
+
div_y (Float[Tensor, "Nf"]): Division y
|
| 497 |
+
island_padding (float): Island padding
|
| 498 |
+
|
| 499 |
+
Returns:
|
| 500 |
+
Float[Tensor, "Nf 3 2"]: Updated UV coordinates
|
| 501 |
+
"""
|
| 502 |
+
# Place the slice first
|
| 503 |
+
placed_uv = self._handle_slice_uvs(face_uv, assigned_faces, island_padding)
|
| 504 |
+
# Then handle the remaining overlap elements
|
| 505 |
+
placed_uv = self._handle_remaining_uvs(
|
| 506 |
+
placed_uv, assigned_faces, island_padding
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
uc, vc = placed_uv.unbind(-1)
|
| 510 |
+
uc = uc / div_x[:, None] + offset_x[:, None]
|
| 511 |
+
vc = vc / div_y[:, None] + offset_y[:, None]
        uv = torch.stack([uc, vc], dim=-1).view(-1, 2)

        return uv

    def _get_unique_face_uv(
        self,
        uv: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get the unique face UV

        Args:
            uv (Float[Tensor, "Nf*3 2"]): UV coordinates

        Returns:
            Float[Tensor, "Utex 2"]: Unique UV coordinates
            Integer[Tensor, "Nf 3"]: Vertex index
        """
        unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0)
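        # Clarifying example (added, not in the original source): torch.unique with
        # return_inverse=True deduplicates rows and maps every original row back to its
        # unique entry, e.g. for uv = [[0, 0], [1, 0], [0, 0]] it returns
        # unique_uv = [[0, 0], [1, 0]] and unique_idx = [0, 1, 0]; reshaping unique_idx
        # to (-1, 3) then yields one UV index per triangle corner.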
        # And add the face to uv index mapping
        vtex_idx = unique_idx.view(-1, 3)

        return unique_uv, vtex_idx

    def _align_mesh_with_main_axis(
        self, vertex_positions: Tensor, vertex_normals: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Align the mesh with the main axis

        Args:
            vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
            vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals

        Returns:
            Float[Tensor, "Nv 3"]: Rotated vertex positions
            Float[Tensor, "Nv 3"]: Rotated vertex normals
        """

        # Use PCA to find the 2 main axes (the third is derived by cross product)
        # Set the random seed so it's repeatable
        torch.manual_seed(0)
        _, _, v = torch.pca_lowrank(vertex_positions, q=2)
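        # Note (added for clarity, not part of the original source): torch.pca_lowrank
        # returns (U, S, V) where the columns of V are the principal directions, so
        # v[:, 0] points along the direction of largest variance in the vertex cloud
        # and v[:, 1] along the second-largest.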
        main_axis, seconday_axis = v[:, 0], v[:, 1]

        main_axis = F.normalize(main_axis, eps=1e-6, dim=-1)  # 3,
        # Orthogonalize the second axis
        seconday_axis = F.normalize(
            seconday_axis
            - (seconday_axis * main_axis).sum(-1, keepdim=True) * main_axis,
            eps=1e-6,
            dim=-1,
        )  # 3,
        # Create perpendicular third axis
        third_axis = F.normalize(
            torch.cross(main_axis, seconday_axis, dim=-1), dim=-1, eps=1e-6
        )  # 3,

        # Check to which canonical axis each aligns
        main_axis_max_idx = main_axis.abs().argmax().item()
        seconday_axis_max_idx = seconday_axis.abs().argmax().item()
        third_axis_max_idx = third_axis.abs().argmax().item()

        # Now sort the axes based on the argmax so they align with the canonical axes
        # If two axes have the same argmax, move one of them
        all_possible_axis = {0, 1, 2}
        cur_index = 1
        while (
            len(set([main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx]))
            != 3
        ):
            # Find missing axis
            missing_axis = all_possible_axis - set(
                [main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx]
            )
            missing_axis = missing_axis.pop()
            # Just assign it to the third axis as it had the smallest contribution to the
            # overall shape
            if cur_index == 1:
                third_axis_max_idx = missing_axis
            elif cur_index == 2:
                seconday_axis_max_idx = missing_axis
            else:
                raise ValueError("Could not find 3 unique axis")
            cur_index += 1

        if len({main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx}) != 3:
            raise ValueError("Could not find 3 unique axis")

        axes = [None] * 3
        axes[main_axis_max_idx] = main_axis
        axes[seconday_axis_max_idx] = seconday_axis
        axes[third_axis_max_idx] = third_axis
        # Create rotation matrix from the individual axes
        rot_mat = torch.stack(axes, dim=1).T
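        # Note (added for clarity, not part of the original source): the axes are unit
        # length and mutually orthogonal, so stacking them and transposing yields an
        # orthonormal rotation matrix. The einsum "ij,nj->ni" below is the batched
        # matrix-vector product rot_mat @ p applied to every vertex p, i.e. the same as
        # vertex_positions @ rot_mat.T.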
        # Now rotate the vertex positions and vertex normals so the mesh aligns with the main axis
        vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
        vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)

        return vertex_positions, vertex_normals

    def forward(
        self,
        vertex_positions: Tensor,
        vertex_normals: Tensor,
        triangle_idxs: Tensor,
        island_padding: float,
    ) -> Tuple[Tensor, Tensor]:
        """
        Unwrap the mesh

        Args:
            vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
            vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals
            triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
            island_padding (float): Island padding

        Returns:
            Float[Tensor, "Utex 2"]: Unique UV coordinates
            Integer[Tensor, "Nf 3"]: Vertex index
        """
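        # Pipeline overview (comment added for clarity, not in the original source):
        # 1. rotate the mesh into its PCA frame, 2. project each triangle onto one of
        # the six bounding-box (cube) faces, 3. rotate the per-face UV slices into a
        # consistent orientation, 4. assign every face to an atlas slot, 5. compute
        # per-slot offsets and scales, 6. pack the UVs into the atlas, and
        # 7. deduplicate the final UV coordinates.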
        vertex_positions, vertex_normals = self._align_mesh_with_main_axis(
            vertex_positions, vertex_normals
        )
        bbox = torch.stack(
            [vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values],
            dim=0,
        )  # 2, 3

        face_uv, face_index = self._box_assign_vertex_to_cube_face(
            vertex_positions, vertex_normals, triangle_idxs, bbox
        )

        face_uv = self._rotate_uv_slices_consistent_space(
            vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
        )

        assigned_atlas_index = self._assign_faces_uv_to_atlas_index(
            vertex_positions, triangle_idxs, face_uv, face_index
        )

        offset_x, offset_y, div_x, div_y = self._find_slice_offset_and_scale(
            assigned_atlas_index
        )

        placed_uv = self._distribute_individual_uvs_in_atlas(
            face_uv,
            assigned_atlas_index,
            offset_x,
            offset_y,
            div_x,
            div_y,
            island_padding,
        )

        return self._get_unique_face_uv(placed_uv)
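# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the original file). It
# assumes the package exports this unwrapping module as `Unwrapper` and that an
# instance can be constructed without arguments; check uv_unwrapper/__init__.py
# for the actual name and constructor.
#
#   import torch
#   from uv_unwrapper import Unwrapper  # assumed export
#
#   unwrapper = Unwrapper()
#   # A single triangle as a toy mesh: positions, per-vertex normals, face indices.
#   v_pos = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
#   v_nrm = torch.tensor([[0.0, 0.0, 1.0]] * 3)
#   tris = torch.tensor([[0, 1, 2]], dtype=torch.long)
#   uv, uv_idx = unwrapper(v_pos, v_nrm, tris, island_padding=1 / 256)
#   # `uv` holds unique UV coordinates; `uv_idx` maps each face corner to one of them.
# ---------------------------------------------------------------------------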