vaibhavpandeyvpz committed on
Commit
eeef97b
·
0 Parent(s):

Deploy to HF spaces

Browse files
Files changed (41) hide show
  1. .gitattributes +2 -0
  2. .gitignore +77 -0
  3. README.md +124 -0
  4. app.py +472 -0
  5. load/tets/160_tets.npz +3 -0
  6. requirements.txt +40 -0
  7. sf3d/models/camera.py +32 -0
  8. sf3d/models/global_estimator/multi_head_estimator.py +118 -0
  9. sf3d/models/image_estimator/clip_based_estimator.py +168 -0
  10. sf3d/models/isosurface.py +229 -0
  11. sf3d/models/mesh.py +289 -0
  12. sf3d/models/network.py +213 -0
  13. sf3d/models/tokenizers/dinov2.py +1196 -0
  14. sf3d/models/tokenizers/image.py +101 -0
  15. sf3d/models/tokenizers/triplane.py +49 -0
  16. sf3d/models/transformers/attention.py +31 -0
  17. sf3d/models/transformers/backbone.py +515 -0
  18. sf3d/models/utils.py +236 -0
  19. sf3d/system.py +534 -0
  20. sf3d/utils.py +105 -0
  21. texture_baker/README.md +26 -0
  22. texture_baker/requirements.txt +2 -0
  23. texture_baker/setup.py +142 -0
  24. texture_baker/texture_baker/__init__.py +4 -0
  25. texture_baker/texture_baker/baker.py +86 -0
  26. texture_baker/texture_baker/csrc/baker.cpp +548 -0
  27. texture_baker/texture_baker/csrc/baker.h +203 -0
  28. texture_baker/texture_baker/csrc/baker_kernel.cu +306 -0
  29. texture_baker/texture_baker/csrc/baker_kernel.metal +170 -0
  30. texture_baker/texture_baker/csrc/baker_kernel.mm +260 -0
  31. uv_unwrapper/README.md +0 -0
  32. uv_unwrapper/requirements.txt +2 -0
  33. uv_unwrapper/setup.py +83 -0
  34. uv_unwrapper/uv_unwrapper/__init__.py +6 -0
  35. uv_unwrapper/uv_unwrapper/csrc/bvh.cpp +381 -0
  36. uv_unwrapper/uv_unwrapper/csrc/bvh.h +118 -0
  37. uv_unwrapper/uv_unwrapper/csrc/common.h +493 -0
  38. uv_unwrapper/uv_unwrapper/csrc/intersect.cpp +702 -0
  39. uv_unwrapper/uv_unwrapper/csrc/intersect.h +10 -0
  40. uv_unwrapper/uv_unwrapper/csrc/unwrapper.cpp +271 -0
  41. uv_unwrapper/uv_unwrapper/unwrap.py +669 -0
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.npz filter=lfs diff=lfs merge=lfs -text
2
+ load/tets/160_tets.npz filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual environments
24
+ venv/
25
+ env/
26
+ ENV/
27
+ .venv
28
+
29
+ # IDE
30
+ .vscode/
31
+ .idea/
32
+ *.swp
33
+ *.swo
34
+ *~
35
+
36
+ # Jupyter Notebook
37
+ .ipynb_checkpoints
38
+
39
+ # Environment variables
40
+ .env
41
+ .env.local
42
+
43
+ # Model cache
44
+ .cache/
45
+ *.safetensors
46
+ *.ckpt
47
+ *.pt
48
+ *.pth
49
+
50
+ # Generated files
51
+ output/
52
+ *.glb
53
+ *.gltf
54
+ *.obj
55
+ *.ply
56
+
57
+ # Gradio temp files
58
+ gradio_cached_examples/
59
+ flagged/
60
+
61
+ # OS
62
+ .DS_Store
63
+ Thumbs.db
64
+
65
+ # Logs
66
+ *.log
67
+ logs/
68
+
69
+ # Temporary files
70
+ tmp/
71
+ temp/
72
+ *.tmp
73
+
74
+ # Hugging Face
75
+ .huggingface/
76
+
77
+ references/
README.md ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Stable Diffusion Fast Text to 3D
3
+ emoji: 🎨
4
+ colorFrom: red
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 6.1.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: other
11
+ license_name: stabilityai-ai-community
12
+ license_link: https://huggingface.co/stabilityai/stable-fast-3d/blob/main/LICENSE.md
13
+ models:
14
+ - stabilityai/stable-diffusion-xl-base-1.0
15
+ - stabilityai/stable-fast-3d
16
+ gpu: true
17
+ ---
18
+
19
+ # Text to Image to 3D Generation
20
+
21
+ This Hugging Face Space provides a complete workflow to generate 3D models from text prompts using:
22
+
23
+ 1. **Stable Diffusion XL** - Generate high-quality images from text prompts
24
+ 2. **rembg** - Remove backgrounds from generated images
25
+ 3. **Stable Fast 3D** - Convert images to 3D mesh models
26
+
27
+ ## Features
28
+
29
+ - 🎨 **Text to Image**: Generate images using Stable Diffusion XL base model
30
+ - ✂️ **Background Removal**: Automatically remove backgrounds using rembg
31
+ - 🎮 **3D Generation**: Create textured 3D mesh models from images
32
+ - 🔄 **Step-by-step Workflow**: Review and confirm at each step
33
+ - ⚙️ **Customizable**: Adjust remeshing options, vertex count, and texture resolution
34
+
35
+ ## How to Use
36
+
37
+ 1. **Step 1 - Text to Image**:
38
+ - Enter your text prompt describing what you want to generate
39
+ - Optionally add a negative prompt to exclude unwanted elements
40
+ - Adjust the number of inference steps (more steps = higher quality, slower)
41
+ - Click "Generate Image" and wait for the result
42
+
43
+ 2. **Step 2 - Background Removal**:
44
+ - Review the generated image
45
+ - Click "Continue to Background Removal" to remove the background
46
+ - Preview the result with transparency
47
+
48
+ 3. **Step 3 - 3D Generation**:
49
+ - Review the background-removed image
50
+ - Adjust 3D generation settings:
51
+ - **Remeshing Option**: Choose "none", "triangle", or "quad" remeshing
52
+ - **Target Vertex Count**: Set to -1 for automatic, or specify a target count
53
+ - **Texture Size**: Choose texture resolution (512-2048)
54
+ - Click "Continue to 3D Generation" to create the 3D model
55
+ - Download your GLB file
56
+
57
+ ## Tips
58
+
59
+ - **Prompts**: Be descriptive and specific. Include style keywords like "3D render", "character", "stylized"
60
+ - **Background Removal**: Works best with clear foreground objects
61
+ - **3D Generation**:
62
+ - Use "none" remeshing for best quality
63
+ - Higher texture sizes produce better quality but take longer
64
+ - Vertex count of -1 uses the model's default
65
+
66
+ ## Technical Details
67
+
68
+ - **Models Used**:
69
+ - `stabilityai/stable-diffusion-xl-base-1.0` for text-to-image
70
+ - `rembg` for background removal
71
+ - `stabilityai/stable-fast-3d` for image-to-3D
72
+
73
+ - **Output Format**: GLB (glTF Binary) files compatible with most 3D software and viewers
74
+
75
+ - **GPU Resource Management**:
76
+ - Uses `@spaces.GPU()` decorators to properly manage GPU resources in Hugging Face Spaces
77
+ - GPU is allocated for text-to-image generation, background removal, and 3D mesh generation
78
+ - Ensures efficient GPU usage across the workflow
79
+
80
+ ## Requirements
81
+
82
+ This Space requires:
83
+ - GPU support (required: the app raises an error at startup when CUDA is not available)
84
+ - Sufficient memory for model loading
85
+ - Internet connection for model downloads
86
+ - **Access to gated models**: The `stabilityai/stable-fast-3d` model is gated. You must:
87
+ 1. Accept the model's terms of use on [Hugging Face](https://huggingface.co/stabilityai/stable-fast-3d)
88
+ 2. Set `HF_TOKEN` in the Space's secrets; the Space authenticates with it automatically at startup
89
+
90
+ ### Dependencies
91
+
92
+ This Space uses several models and packages:
93
+
94
+ 1. **Stable Diffusion XL**: Automatically downloaded from Hugging Face
95
+ 2. **rembg**: Installed via pip (included in requirements.txt)
96
+ 3. **Stable Fast 3D**:
97
+ - The model weights are downloaded from Hugging Face
98
+ - The `sf3d` Python package is included in this repository
99
+ - **Note**: This is a gated model - access must be granted by Stability AI
100
+ 4. **texture_baker and uv_unwrapper**:
101
+ - These packages are included in the repository
102
+ - They are automatically compiled and installed at runtime when the app starts
103
+ - Installation may take a few minutes on first run
104
+ - CUDA architecture is automatically detected or uses fallback architectures
105
+
106
+ ### Authentication
107
+
108
+ - **Hugging Face Token**: The Space automatically authenticates using the `HF_TOKEN` environment variable
109
+ - **Gated Model Access**: You must accept the terms and request access to `stabilityai/stable-fast-3d` on Hugging Face
110
+ - The authentication happens at startup, so all model downloads use the authenticated session
111
+
112
+ All required packages (`sf3d`, `texture_baker`, `uv_unwrapper`, and `load/tets`) are included in this repository, so no additional setup is needed.
113
+
114
+ ## License
115
+
116
+ This Space uses models with the following licenses:
117
+ - Stable Diffusion XL: [CreativeML Open RAIL++-M License](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md)
118
+ - Stable Fast 3D: [Stability AI Community License](https://huggingface.co/stabilityai/stable-fast-3d/blob/main/LICENSE.md)
119
+
120
+ ## Credits
121
+
122
+ - [Stability AI](https://stability.ai/) for Stable Diffusion XL and Stable Fast 3D
123
+ - [rembg](https://github.com/danielgatis/rembg) for background removal
124
+ - Built with [Gradio](https://gradio.app/)
app.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import torch
3
+ import os
4
+ import tempfile
5
+ import time
6
+ from contextlib import nullcontext
7
+ from functools import lru_cache
8
+ from typing import Any
9
+
10
+ import gradio as gr
11
+ import numpy as np
12
+ import rembg
13
+ from diffusers import DiffusionPipeline
14
+ from gradio_litmodel3d import LitModel3D
15
+ from huggingface_hub import login
16
+ from PIL import Image
17
+
18
# Authenticate with the Hugging Face Hub using the token from the environment.
# The token is expected to be supplied as HF_TOKEN (e.g. via the Space's secrets).
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    print("Warning: HF_TOKEN not found. Gated models may not be accessible.")
    print("Please ensure HF_TOKEN is set in your Space's secrets.")
else:
    # login() stores the token for subsequent Hub operations.
    login(token=hf_token)
    # Re-export it for any library that reads the variable directly.
    os.environ["HF_TOKEN"] = hf_token
    print("Authenticated with Hugging Face")
30
+
31
# The bundled extensions are built with CUDA support, so refuse to start
# without a CUDA device instead of failing later during the build.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available")

# Set environment variables for building texture_baker and uv_unwrapper
os.environ["USE_CUDA"] = "1"
os.environ["USE_NATIVE_ARCH"] = "0"  # Disable native arch to avoid build issues

# Set TORCH_CUDA_ARCH_LIST explicitly: PyTorch's build system fails when it
# cannot detect GPU architectures on its own.
# (CUDA availability is already guaranteed by the guard above, so no second
# availability check is needed here.)
try:
    # Try to get the actual compute capability of device 0
    compute_cap = torch.cuda.get_device_capability(0)
except Exception as e:
    # Fallback to common architectures if detection fails
    # Include multiple architectures to support various GPU models
    fallback_archs = "7.0;7.5;8.0;8.6;8.9;9.0"
    os.environ["TORCH_CUDA_ARCH_LIST"] = fallback_archs
    print(
        f"Could not detect CUDA capability: {e}, using fallback architectures: {fallback_archs}"
    )
else:
    cuda_arch = f"{compute_cap[0]}.{compute_cap[1]}"
    os.environ["TORCH_CUDA_ARCH_LIST"] = cuda_arch
    print(
        f"Detected CUDA capability: {cuda_arch}, setting TORCH_CUDA_ARCH_LIST={cuda_arch}"
    )
61
+
62
# Build and install the bundled extensions at runtime.
# Uses the interpreter that runs this app (sys.executable -m pip) so the
# packages land in the right environment, passes the arguments as a list
# instead of a shell string, and reports a non-zero exit status instead of
# silently ignoring it.
import subprocess
import sys

_extension_install = subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "-vv",
        "--no-build-isolation",
        "./texture_baker",
        "./uv_unwrapper",
    ],
    env={**os.environ, "USE_CUDA": "1", "USE_NATIVE_ARCH": "0"},
    check=False,
)
if _extension_install.returncode != 0:
    # Best effort, as before: keep going so a previously installed build can
    # still be imported, but make the failure visible in the logs.
    print(
        "Warning: installing texture_baker/uv_unwrapper failed with exit code "
        f"{_extension_install.returncode}"
    )
65
+
66
+ import sf3d.utils as sf3d_utils
67
+ from sf3d.system import SF3D
68
+
69
# Set up environment: point Gradio's temp dir at "<TMPDIR or /tmp>/gradio".
# This directory is wiped in the __main__ block before the app launches.
os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.environ.get("TMPDIR", "/tmp"), "gradio")

# Initialize rembg session once; it is reused for every background removal.
rembg_session = rembg.new_session()

# Constants for 3D generation
# Size the conditioning image is resized to before it is fed to the model.
COND_WIDTH = 512
COND_HEIGHT = 512
# Passed to sf3d_utils.default_cond_c2w() below.
COND_DISTANCE = 1.6
# Field of view in degrees, passed to create_intrinsic_from_fov_deg() below.
COND_FOVY_DEG = 40
# Colour blended in behind the foreground (by alpha) in create_batch().
BACKGROUND_COLOR = [0.5, 0.5, 0.5]

# Cached. Doesn't change between requests, so it is computed once at import.
c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
    COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
)

# Paths of every GLB written by run_model(); only ever appended to.
generated_files = []

# Initialize device and SF3D model (like official app)
device = sf3d_utils.get_device()

# SF3D model - initialized at startup like official app
# Token is automatically used after login() call above
sf3d_model = SF3D.from_pretrained(
    "stabilityai/stable-fast-3d",
    config_name="config.yaml",
    weight_name="model.safetensors",
)
sf3d_model.eval()
sf3d_model = sf3d_model.to(device)

# SDXL pipeline - lazy loaded to save memory (see initialize_sdxl())
sd_pipeline = None
105
+
106
+
107
def initialize_sdxl():
    """Return the shared SDXL pipeline, loading it on the first call.

    The pipeline is stored in the module-level ``sd_pipeline`` so later calls
    reuse it. Placement depends on the module-level ``device``:
    ``"cuda"`` loads fp16 weights and moves the pipeline to the GPU,
    ``"mps"`` moves the fp32 pipeline to the device, and anything else
    enables model CPU offload.

    Returns:
        The initialised ``DiffusionPipeline``.
    """
    global sd_pipeline

    if sd_pipeline is not None:
        return sd_pipeline

    print("Loading Stable Diffusion XL model...")
    on_cuda = device == "cuda"
    pipeline = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16 if on_cuda else torch.float32,
        use_safetensors=True,
        variant="fp16" if on_cuda else None,
    )
    if on_cuda:
        pipeline = pipeline.to(device)
        # Memory-efficient attention is optional: keep going without it, but
        # say why instead of swallowing every exception silently.
        try:
            pipeline.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print(f"xformers memory efficient attention not enabled: {e}")
    elif device == "mps":
        pipeline = pipeline.to(device)
    else:
        pipeline.enable_model_cpu_offload()

    # Publish the pipeline only once it is fully configured, so a failure
    # above cannot leave a half-initialised pipeline cached for later calls.
    sd_pipeline = pipeline
    print("SDXL model loaded!")

    return sd_pipeline
133
+
134
+
135
@spaces.GPU()
def generate_text_to_image(
    prompt: str, negative_prompt: str = "", num_inference_steps: int = 30
):
    """Run the SDXL pipeline on a text prompt and return the first image.

    An empty ``negative_prompt`` is forwarded to the pipeline as ``None``.
    On CUDA the call runs under fp16 autocast; on other devices it runs
    without autocast. Gradients are disabled in both cases.
    """
    pipeline = initialize_sdxl()

    print(f"Generating image from prompt: {prompt}")

    call_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt if negative_prompt else None,
        num_inference_steps=num_inference_steps,
    )
    if device == "cuda":
        amp_ctx = torch.autocast(device_type="cuda", dtype=torch.float16)
    else:
        amp_ctx = nullcontext()

    with torch.no_grad(), amp_ctx:
        image = pipeline(**call_kwargs).images[0]

    return image
161
+
162
+
163
@spaces.GPU()
def remove_background_from_image(image: Image.Image) -> Image.Image:
    """Return ``image`` with its background removed by rembg.

    Uses the module-level ``rembg_session`` so the session is created once.
    """
    print("Removing background...")
    return rembg.remove(image, session=rembg_session)
169
+
170
+
171
def create_batch(input_image: Image.Image) -> dict[str, Any]:
    """Create batch for SF3D model - matches official app structure.

    Args:
        input_image: Foreground image. Its alpha channel is used as the
            foreground mask; an image without alpha is treated as fully
            opaque.

    Returns:
        A dict with ``rgb_cond``, ``mask_cond``, ``c2w_cond``,
        ``intrinsic_cond`` and ``intrinsic_normed_cond`` tensors, each with a
        leading batch dimension added.
    """
    # The last channel is read as the mask below, so make sure it really is
    # alpha. Without this an RGB image would silently use its blue channel as
    # the mask. For RGBA input the conversion changes nothing.
    rgba_image = input_image.convert("RGBA").resize((COND_WIDTH, COND_HEIGHT))
    img_cond = (
        torch.from_numpy(np.asarray(rgba_image).astype(np.float32) / 255.0)
        .float()
        .clip(0, 1)
    )
    mask_cond = img_cond[:, :, -1:]
    # Blend the foreground over the flat background colour by alpha.
    rgb_cond = torch.lerp(
        torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
    )

    batch_elem = {
        "rgb_cond": rgb_cond,
        "mask_cond": mask_cond,
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
    }
    # Add batch dim
    batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
    return batched
196
+
197
+
198
def run_model(input_image, remesh_option, vertex_count, texture_size):
    """Run SF3D model - matches official app structure.

    Builds the model batch from ``input_image``, generates a mesh with the
    module-level ``sf3d_model`` and exports it as a GLB file.

    Args:
        input_image: Image passed to ``create_batch``.
        remesh_option: Remeshing mode; lower-cased before it is passed on.
        vertex_count: Target vertex count forwarded to ``generate_mesh``.
        texture_size: Texture resolution forwarded to ``generate_mesh``.

    Returns:
        Path of the exported ``.glb`` file. The path is also appended to the
        module-level ``generated_files`` list.
    """
    start = time.time()
    autocast_ctx = (
        torch.autocast(device_type=device, dtype=torch.bfloat16)
        if "cuda" in device
        else nullcontext()
    )
    with torch.no_grad(), autocast_ctx:
        model_batch = create_batch(input_image)
        model_batch = {k: v.to(device) for k, v in model_batch.items()}
        trimesh_mesh, _glob_dict = sf3d_model.generate_mesh(
            model_batch, texture_size, remesh_option.lower(), vertex_count
        )
        trimesh_mesh = trimesh_mesh[0]

    # Reserve a unique path, then close the handle right away: the exporter
    # writes by file name, and leaving the handle open leaked one file
    # descriptor per generated model.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".glb") as tmp_file:
        glb_path = tmp_file.name

    trimesh_mesh.export(glb_path, file_type="glb", include_normals=True)
    generated_files.append(glb_path)

    print("Generation took:", time.time() - start, "s")

    return glb_path
223
+
224
+
225
@spaces.GPU()
def generate_3d_from_image(
    input_image: Image.Image,
    remesh_option: str = "none",
    vertex_count: int = -1,
    texture_size: int = 1024,
) -> str:
    """Turn a background-removed image into a GLB mesh with SF3D.

    The image is first passed through ``sf3d_utils.resize_foreground`` with a
    foreground ratio of 0.85 and the conditioning size, then handed to
    ``run_model``. Returns the path ``run_model`` produces.
    """
    foreground_ratio = 0.85
    cond_size = (COND_WIDTH, COND_HEIGHT)
    processed_image = sf3d_utils.resize_foreground(
        input_image, foreground_ratio, out_size=cond_size
    )
    glb_path = run_model(processed_image, remesh_option, vertex_count, texture_size)
    return glb_path
240
+
241
+
242
@lru_cache
def checkerboard(squares: int, size: int, min_value: float = 0.5):
    """Create checkerboard pattern for transparency preview.

    Args:
        squares: Number of squares along each side.
        size: Requested side length in pixels. Each square is
            ``size // squares`` pixels wide, so the actual side length is
            ``squares * (size // squares)``.
        min_value: Value of the darker squares; the lighter squares are 1.

    Returns:
        A read-only float array of shape ``(side, side, 3)``. The result is
        cached and shared between callers, so it is marked non-writeable to
        stop one caller from corrupting the pattern for the others; derive a
        new array (e.g. ``checkerboard(...) * 255``) to modify it.
    """
    base = np.full((squares, squares), min_value, dtype=float)
    base[1::2, ::2] = 1
    base[::2, 1::2] = 1

    repeat_mult = size // squares
    board = (
        base.repeat(repeat_mult, axis=0)
        .repeat(repeat_mult, axis=1)[:, :, None]
        .repeat(3, axis=-1)
    )
    board.flags.writeable = False
    return board
255
+
256
+
257
def show_mask_preview(input_image: Image.Image) -> Image.Image:
    """Show image with checkerboard background for transparency preview.

    Args:
        input_image: Image to preview. It is converted to RGBA first, so an
            image without alpha is shown fully opaque. ``None`` (e.g. the
            image component was cleared) yields ``None``.

    Returns:
        An RGB image of the same size with the checkerboard visible wherever
        the input is transparent, or ``None`` when there is no input.
    """
    if input_image is None:
        return None

    squares = 32
    img_numpy = np.asarray(input_image.convert("RGBA"))
    height, width = img_numpy.shape[:2]
    alpha = img_numpy[:, :, 3:] / 255.0

    # The checkerboard used to be fixed at 512 px, which failed to broadcast
    # against any other image size (SDXL outputs are larger). Build one that
    # covers the longer side, rounded up to a whole number of squares, and
    # crop it to the image.
    side = -(-max(height, width) // squares) * squares
    chkb = checkerboard(squares, side)[:height, :width] * 255

    new_img = img_numpy[..., :3] * alpha + chkb * (1 - alpha)
    # A (H, W, 3) uint8 array is interpreted as RGB without an explicit mode.
    return Image.fromarray(new_img.astype(np.uint8))
264
+
265
+
266
+ # Gradio Interface Functions
267
def step1_generate_image(prompt, negative_prompt, num_steps):
    """Gradio handler for step 1: text prompt to image.

    Returns a 3-tuple for (image, continue-button update, status text). The
    continue button is revealed only when an image was generated; on a
    missing prompt or an error the image is ``None`` and the button hidden.
    """
    if not prompt:
        return None, gr.update(visible=False), "Please enter a prompt"

    try:
        generated = generate_text_to_image(prompt, negative_prompt, num_steps)
        continue_button = gr.update(
            visible=True, value="Continue to Background Removal"
        )
        status = (
            "Image generated successfully! Review and continue to remove background."
        )
        return generated, continue_button, status
    except Exception as e:
        failure = f"Error generating image: {str(e)}"
        return None, gr.update(visible=False), failure
281
+
282
+
283
def step2_remove_background(image):
    """Gradio handler for step 2: strip the background from the image.

    Returns a 4-tuple for (background-removed image, transparency preview,
    continue-button update, status text). Without an input image, or on an
    error, both images are ``None`` and the continue button stays hidden.
    """
    if image is None:
        return None, None, gr.update(visible=False), "Please generate an image first"

    try:
        # Background removal is always fed an RGB image.
        rgb_image = image if image.mode == "RGB" else image.convert("RGB")

        bg_removed = remove_background_from_image(rgb_image)
        preview = show_mask_preview(bg_removed)

        continue_button = gr.update(visible=True, value="Continue to 3D Generation")
        status = (
            "Background removed successfully! Review and continue to generate 3D model."
        )
        return bg_removed, preview, continue_button, status
    except Exception as e:
        failure = f"Error removing background: {str(e)}"
        return None, None, gr.update(visible=False), failure
309
+
310
+
311
def step3_generate_3d(image_with_bg_removed, remesh_option, vertex_count, texture_size):
    """Gradio handler for step 3: background-removed image to 3D model.

    Returns a 2-tuple for (3D viewer update, status text). The viewer is
    shown with the generated GLB path on success and hidden with no value
    when there is no input image or generation fails.
    """
    if image_with_bg_removed is None:
        return gr.update(value=None, visible=False), "Please remove background first"

    try:
        glb_file = generate_3d_from_image(
            image_with_bg_removed, remesh_option, vertex_count, texture_size
        )
        viewer = gr.update(value=glb_file, visible=True)
        return viewer, "3D model generated successfully! You can download it below."
    except Exception as e:
        failure = f"Error generating 3D model: {str(e)}"
        return gr.update(value=None, visible=False), failure
330
+
331
+
332
+ # Create Gradio Interface
333
with gr.Blocks(title="Text to Image to 3D") as demo:
    gr.Markdown(
        """
    # Text to Image to 3D Generation

    This app allows you to generate 3D models from text prompts in three steps:
    1. **Text to Image**: Generate an image using Stable Diffusion XL
    2. **Background Removal**: Remove the background using rembg
    3. **3D Generation**: Create a 3D mesh model using Stable Fast 3D

    **Instructions:**
    - Enter your text prompt and generate an image
    - Review the generated image and continue to remove the background
    - Review the background-removed image and continue to generate the 3D model
    - Download your 3D model as a GLB file
    """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Step 1: Text to Image")
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="A cute robot character, 3D render, colorful",
                lines=2,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt (optional)",
                placeholder="blurry, low quality, distorted",
                lines=2,
            )
            num_steps = gr.Slider(
                label="Number of Inference Steps",
                minimum=20,
                maximum=50,
                value=30,
                step=5,
            )
            generate_btn = gr.Button("Generate Image", variant="primary")
            step1_status = gr.Textbox(label="Status", interactive=False)

            step1_image = gr.Image(label="Generated Image", type="pil")
            step1_continue_btn = gr.Button(
                "Continue to Background Removal",
                visible=False,
                variant="secondary",
            )

        with gr.Column(scale=1):
            gr.Markdown("### Step 2: Background Removal")
            # image_mode="RGBA" keeps the alpha channel when this component is
            # read back as an event input; with the RGB default the mask
            # produced by background removal is dropped before step 3.
            step2_image = gr.Image(
                label="Image with Background Removed",
                type="pil",
                image_mode="RGBA",
            )
            step2_preview = gr.Image(
                label="Preview (with transparency)",
                type="pil",
                visible=False,
            )
            step2_status = gr.Textbox(label="Status", interactive=False)
            step2_continue_btn = gr.Button(
                "Continue to 3D Generation",
                visible=False,
                variant="secondary",
            )

        with gr.Column(scale=1):
            gr.Markdown("### Step 3: 3D Generation")
            remesh_option = gr.Radio(
                choices=["none", "triangle", "quad"],
                label="Remeshing Option",
                value="none",
            )
            vertex_count = gr.Slider(
                label="Target Vertex Count (-1 for auto)",
                minimum=-1,
                maximum=20000,
                value=-1,
                step=100,
            )
            texture_size = gr.Slider(
                label="Texture Size",
                minimum=512,
                maximum=2048,
                value=1024,
                step=256,
            )
            step3_generate_btn = gr.Button("Generate 3D Model", variant="primary")
            step3_status = gr.Textbox(label="Status", interactive=False)

            step3_output = LitModel3D(
                label="3D Model",
                visible=False,
                clear_color=[0.0, 0.0, 0.0, 0.0],
            )

    # State variables
    # Holds the background-removed image that step 3 consumes.
    step2_image_state = gr.State()

    # Event handlers
    generate_btn.click(
        fn=step1_generate_image,
        inputs=[prompt, negative_prompt, num_steps],
        outputs=[step1_image, step1_continue_btn, step1_status],
    )

    step1_continue_btn.click(
        fn=step2_remove_background,
        inputs=[step1_image],
        outputs=[step2_image, step2_preview, step2_continue_btn, step2_status],
    ).then(
        fn=lambda img: img,
        inputs=[step2_image],
        outputs=[step2_image_state],
    )

    # Both step-3 buttons start generation. "Generate 3D Model" previously had
    # no handler attached, so clicking it did nothing.
    for _step3_trigger in (step2_continue_btn, step3_generate_btn):
        _step3_trigger.click(
            fn=step3_generate_3d,
            inputs=[step2_image_state, remesh_option, vertex_count, texture_size],
            outputs=[step3_output, step3_status],
        )

    def _refresh_preview(img):
        """Rebuild the transparency preview, hiding it when the image is cleared."""
        if img is None:
            return gr.update(value=None, visible=False)
        return gr.update(value=show_mask_preview(img), visible=True)

    # Update preview when image changes
    step2_image.change(
        fn=_refresh_preview,
        inputs=[step2_image],
        outputs=[step2_preview],
    )
461
+
462
+
463
if __name__ == "__main__":
    # Start from a clean Gradio temp directory (like official app)
    gradio_tmp_dir = os.environ["GRADIO_TEMP_DIR"]
    if os.path.exists(gradio_tmp_dir):
        print(f"Deleting {gradio_tmp_dir}")
        import shutil

        shutil.rmtree(gradio_tmp_dir)

    demo.queue()
    demo.launch(share=False)
load/tets/160_tets.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f4be37efc604d28d55a1a78c2aabefeeab7e63149f541aa45f9dd858ee35bb9
3
+ size 15408790
requirements.txt ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wheel
2
+ setuptools==69.5.1
3
+
4
+ # Core dependencies
5
+ torch==2.5.1
6
+ torchvision==0.20.1
7
+ numpy==1.26.4
8
+ Pillow>=9.5.0
9
+
10
+ # Stable Diffusion XL
11
+ diffusers>=0.21.0
12
+ transformers==4.42.3
13
+ accelerate>=0.20.0
14
+ safetensors>=0.3.0
15
+ invisible-watermark>=0.2.0
16
+
17
+ # Background removal
18
+ rembg[gpu]==2.0.57; sys_platform != 'darwin'
19
+ rembg==2.0.57; sys_platform == 'darwin'
20
+
21
+ # Stable Fast 3D dependencies
22
+ einops==0.7.0
23
+ jaxtyping==0.2.31
24
+ omegaconf==2.3.0
25
+ open_clip_torch==2.24.0
26
+ trimesh==4.4.1
27
+ huggingface-hub>=0.23.2,<1.0
28
+ pynanoinstantmeshes==0.0.3
29
+ gpytoolbox==0.2.0
30
+
31
+ # Gradio and UI
32
+ gradio==4.41.0
33
+ gradio-litmodel3d==0.0.1
34
+
35
+ # Additional utilities
36
+ tqdm>=4.65.0
37
+
38
+ # (HF hack) These are installed at runtime in app.py
39
+ # ./texture_baker/
40
+ # ./uv_unwrapper/
sf3d/models/camera.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from sf3d.models.utils import BaseModule
8
+
9
+
10
class LinearCameraEmbedder(BaseModule):
    """Embed concatenated camera conditions with a single linear layer."""

    @dataclass
    class Config(BaseModule.Config):
        # Width of the concatenated, flattened condition vector.
        in_channels: int = 25
        # Width of the produced embedding.
        out_channels: int = 768
        # Names of the keyword arguments forward() reads, in concatenation order.
        conditions: List[str] = field(default_factory=list)

    cfg: Config

    def configure(self) -> None:
        self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)

    def forward(self, **kwargs):
        """Flatten each configured condition and project the concatenation.

        Every name in ``cfg.conditions`` must be present in ``kwargs``; each
        tensor keeps its first two dimensions and has the rest flattened.
        """
        flattened = []
        for cond_name in self.cfg.conditions:
            assert cond_name in kwargs
            cond = kwargs[cond_name]
            # cond in shape (B, Nv, ...)
            leading_dims = cond.shape[:2]
            flattened.append(cond.view(*leading_dims, -1))
        cond_tensor = torch.cat(flattened, dim=-1)
        assert cond_tensor.shape[-1] == self.cfg.in_channels
        return self.linear(cond_tensor)
sf3d/models/global_estimator/multi_head_estimator.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, List, Optional
3
+
4
+ import torch.nn as nn
5
+ from jaxtyping import Float
6
+ from torch import Tensor
7
+
8
+ from sf3d.models.network import get_activation
9
+ from sf3d.models.utils import BaseModule
10
+
11
+
12
# Description of one output head of MultiHeadEstimator.
@dataclass
class HeadSpec:
    # Key of the head in the module dict and (optionally prefixed) in the output.
    name: str
    # Output width of the head's final linear layer.
    out_channels: int
    # Number of hidden linear + activation pairs before the final layer.
    n_hidden_layers: int
    # Name handed to get_activation() for the head's output.
    output_activation: Optional[str] = None
    # Constant added to the head's raw output before the output activation.
    output_bias: float = 0.0
    # When True, the output key is prefixed with "decoder_".
    add_to_decoder_features: bool = False
    # When set (non-empty), the head's output is reshaped to this shape.
    shape: Optional[list[int]] = None
21
+
22
+
23
class MultiHeadEstimator(BaseModule):
    """Predict several named outputs from a triplane.

    A stack of strided convolutions reduces the triplane, the result is
    pooled over the spatial dimensions, and one MLP per configured head maps
    the pooled features to that head's output.
    """

    @dataclass
    class Config(BaseModule.Config):
        triplane_features: int = 1024

        n_layers: int = 2
        hidden_features: int = 512
        activation: str = "relu"

        pool: str = "max"
        # Literal["mean", "max"] = "mean" # noqa: F821

        heads: List[HeadSpec] = field(default_factory=lambda: [])

    cfg: Config

    def configure(self):
        """Build the shared convolution stack and one MLP per head."""
        hidden = self.cfg.hidden_features

        conv_stack = []
        # The three planes are stacked along the channel axis in forward().
        in_features = self.cfg.triplane_features * 3
        for _ in range(self.cfg.n_layers):
            conv_stack.extend(
                [
                    nn.Conv2d(
                        in_features,
                        hidden,
                        kernel_size=3,
                        padding=0,
                        stride=2,
                    ),
                    self.make_activation(self.cfg.activation),
                ]
            )
            in_features = hidden

        self.layers = nn.Sequential(*conv_stack)

        assert len(self.cfg.heads) > 0
        head_modules = {}
        for head in self.cfg.heads:
            mlp = []
            for _ in range(head.n_hidden_layers):
                mlp.append(nn.Linear(hidden, hidden))
                mlp.append(self.make_activation(self.cfg.activation))
            mlp.append(nn.Linear(hidden, head.out_channels))
            head_modules[head.name] = nn.Sequential(*mlp)
        self.heads = nn.ModuleDict(head_modules)

    def make_activation(self, activation):
        """Return a new in-place activation module for ``"relu"`` or ``"silu"``."""
        if activation == "relu":
            return nn.ReLU(inplace=True)
        if activation == "silu":
            return nn.SiLU(inplace=True)
        raise NotImplementedError

    def forward(
        self,
        triplane: Float[Tensor, "B 3 F Ht Wt"],
    ) -> dict[str, Any]:
        """Compute every head's output, keyed by (optionally prefixed) head name."""
        batch_size = triplane.shape[0]
        plane_h, plane_w = triplane.shape[-2], triplane.shape[-1]
        x = self.layers(triplane.reshape(batch_size, -1, plane_h, plane_w))

        pool = self.cfg.pool
        if pool == "max":
            x = x.amax(dim=[-2, -1])
        elif pool == "mean":
            x = x.mean(dim=[-2, -1])
        else:
            raise NotImplementedError

        out = {}
        for head in self.cfg.heads:
            prefix = "decoder_" if head.add_to_decoder_features else ""
            raw = self.heads[head.name](x) + head.output_bias
            out[prefix + head.name] = get_activation(head.output_activation)(raw)

        # Second pass: apply the optional per-head output shape.
        for head in self.cfg.heads:
            if not head.shape:
                continue
            prefix = "decoder_" if head.add_to_decoder_features else ""
            key = prefix + head.name
            out[key] = out[key].reshape(*head.shape)

        return out
sf3d/models/image_estimator/clip_based_estimator.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, List, Optional
3
+
4
+ import open_clip
5
+ import torch
6
+ import torch.nn as nn
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+ from torchvision.transforms import Normalize
10
+
11
+ from sf3d.models.network import get_activation
12
+ from sf3d.models.utils import BaseModule
13
+
14
+
15
@dataclass
class HeadSpec:
    """Configuration record describing one output head of the estimator."""

    # Key under which the head's module is stored and its output is returned.
    name: str
    # NOTE(review): not read by ClipBasedHeadEstimator in this file (its
    # distribution-parameter branches always end in a single output unit).
    out_channels: int
    # Number of Linear + activation pairs in the head's shared trunk.
    n_hidden_layers: int
    # Name passed to `get_activation` for the sampled value; None means identity.
    output_activation: Optional[str] = None
    # Constant added to the raw distribution parameters.
    output_bias: float = 0.0
    # When True the result is returned under "decoder_<name>" instead of "<name>".
    add_to_decoder_features: bool = False
    # Optional target shape the sampled output is reshaped to.
    shape: Optional[list[int]] = None
24
+
25
+
26
class ClipBasedHeadEstimator(BaseModule):
    """Estimate per-head probability distributions from a frozen CLIP embedding.

    The conditioning image is resized to 224x224, normalised with the OpenAI
    CLIP statistics and encoded by an ``open_clip`` model whose parameters are
    frozen. Every configured head owns a shared trunk plus two small branches
    that each emit one scalar; the two scalars parameterise either a Normal or
    a Beta distribution.
    """

    @dataclass
    class Config(BaseModule.Config):
        model: str = "ViT-B-32"
        pretrain: str = "laion2b_s34b_b79k"

        distribution: str = "beta"

        # ["mean", "mode", "sample", "sample_mean"]
        distribution_eval: str = "mode"

        activation: str = "relu"
        hidden_features: int = 512
        heads: List[HeadSpec] = field(default_factory=list)

    cfg: Config

    def configure(self):
        """Load the CLIP backbone, freeze it and build the per-head MLPs."""
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.cfg.model, pretrained=self.cfg.pretrain
        )
        self.model.eval()

        # Do not add the weights in self.model to the optimizer
        for param in self.model.parameters():
            param.requires_grad = False

        assert len(self.cfg.heads) > 0
        heads = {}
        for head in self.cfg.heads:
            head_layers = []

            for _ in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                ]

            # Entry 0 is the shared trunk; entries 1 and 2 are the two
            # branches producing the distribution parameters.
            head_layers = [nn.Sequential(*head_layers)]
            head_layers += [
                nn.Sequential(
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                    nn.Linear(self.cfg.hidden_features, 1),
                )
                for _ in range(2)
            ]
            heads[head.name] = nn.ModuleList(head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        """Return an in-place ``nn.ReLU`` or ``nn.SiLU`` for ``activation``.

        Raises:
            NotImplementedError: for any name other than "relu" / "silu".
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def forward(
        self,
        cond_image: Float[Tensor, "B 1 H W 3"],
        sample: bool = True,
    ) -> dict[str, Any]:
        """Predict one distribution per head and optionally reduce it to a value.

        Args:
            cond_image: channel-last conditioning image(s); the first two axes
                are flattened together before encoding.
            sample: when True, each distribution is reduced according to
                ``cfg.distribution_eval`` and stored under the head name, with
                the distribution itself kept under ``"<name>_dist"``. When
                False, the distribution objects are returned directly.

        Returns:
            Dict of outputs; heads with ``add_to_decoder_features`` are moved
            to the key ``"decoder_<name>"``.

        Raises:
            NotImplementedError: for an unsupported ``cfg.distribution``.
            ValueError: if a head requests a reshape while ``sample`` is False.
        """
        # Run the model
        # Resize cond_image to 224
        cond_image = nn.functional.interpolate(
            cond_image.flatten(0, 1).permute(0, 3, 1, 2).contiguous(),
            size=(224, 224),
            mode="bilinear",
            align_corners=False,
        )
        cond_image = Normalize(
            mean=open_clip.constants.OPENAI_DATASET_MEAN,
            std=open_clip.constants.OPENAI_DATASET_STD,
        )(cond_image)
        image_features = self.model.encode_image(cond_image)

        # Run the heads
        outputs = {}

        for head_dict in self.cfg.heads:
            head_name = head_dict.name
            shared_head, d1_h, d2_h = self.heads[head_name]
            shared_features = shared_head(image_features)
            d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]
            if self.cfg.distribution == "normal":
                mean = d1
                var = d2
                if mean.shape[-1] == 1:
                    outputs[head_name] = torch.distributions.Normal(
                        mean + head_dict.output_bias,
                        torch.nn.functional.softplus(var),
                    )
                else:
                    outputs[head_name] = torch.distributions.MultivariateNormal(
                        mean + head_dict.output_bias,
                        torch.nn.functional.softplus(var).diag_embed(),
                    )
            elif self.cfg.distribution == "beta":
                # softplus keeps both concentration parameters positive
                outputs[head_name] = torch.distributions.Beta(
                    torch.nn.functional.softplus(d1 + head_dict.output_bias),
                    torch.nn.functional.softplus(d2 + head_dict.output_bias),
                )
            else:
                raise NotImplementedError

        if sample:
            for head_dict in self.cfg.heads:
                head_name = head_dict.name
                dist = outputs[head_name]

                if self.cfg.distribution_eval == "mean":
                    out = dist.mean
                elif self.cfg.distribution_eval == "mode":
                    out = dist.mode
                elif self.cfg.distribution_eval == "sample_mean":
                    # `sample([10])` prepends the sample axis, so the Monte
                    # Carlo average has to reduce axis 0 (not the last axis,
                    # which would mix different batch entries).
                    out = dist.sample([10]).mean(0)
                else:
                    # use rsample if gradient is needed
                    out = dist.rsample() if self.training else dist.sample()

                outputs[head_name] = get_activation(head_dict.output_activation)(out)
                outputs[f"{head_name}_dist"] = dist

        for head in self.cfg.heads:
            if head.shape:
                if not sample:
                    raise ValueError(
                        "Cannot reshape non-sampled probabilisitic outputs"
                    )
                outputs[head.name] = outputs[head.name].reshape(*head.shape)

            if head.add_to_decoder_features:
                outputs[f"decoder_{head.name}"] = outputs[head.name]
                del outputs[head.name]

        return outputs
sf3d/models/isosurface.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from jaxtyping import Float, Integer
7
+ from torch import Tensor
8
+
9
+ from .mesh import Mesh
10
+
11
+
12
class IsosurfaceHelper(nn.Module):
    """Base class for modules that extract a surface from values on a grid.

    Subclasses are expected to override `grid_vertices`.
    """

    # (lower, upper) bound pair; subclasses in this file use its extent to
    # scale vertex offsets.
    points_range: Tuple[float, float] = (0, 1)

    @property
    def grid_vertices(self) -> Float[Tensor, "N 3"]:
        """Return the grid vertex positions; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def requires_instance_per_batch(self) -> bool:
        """Return False by default; subclasses may override."""
        return False
22
+
23
+
24
class MarchingTetrahedraHelper(IsosurfaceHelper):
    """Extract a triangle mesh from per-vertex level values on a tetrahedral grid.

    The grid (vertex positions and tetrahedron vertex indices) is read from an
    ``.npz`` archive with the arrays ``"vertices"`` and ``"indices"``. All lookup
    tables and grid tensors are registered as non-persistent buffers so they
    follow the module across devices without entering the state dict.
    """

    def __init__(self, resolution: int, tets_path: str):
        """Load the tetrahedral grid and build the lookup tables.

        Args:
            resolution: grid resolution, used to scale vertex deformations.
            tets_path: path of the ``.npz`` archive holding the grid.
        """
        super().__init__()
        self.resolution = resolution
        self.tets_path = tets_path

        # For each of the 16 occupancy patterns of a tetrahedron: indices (into
        # the tet's 6 edges) of the triangle corners, padded with -1.
        self.triangle_table: Float[Tensor, "..."]
        self.register_buffer(
            "triangle_table",
            torch.as_tensor(
                [
                    [-1, -1, -1, -1, -1, -1],
                    [1, 0, 2, -1, -1, -1],
                    [4, 0, 3, -1, -1, -1],
                    [1, 4, 2, 1, 3, 4],
                    [3, 1, 5, -1, -1, -1],
                    [2, 3, 0, 2, 5, 3],
                    [1, 4, 0, 1, 5, 4],
                    [4, 2, 5, -1, -1, -1],
                    [4, 5, 2, -1, -1, -1],
                    [4, 1, 0, 4, 5, 1],
                    [3, 2, 0, 3, 5, 2],
                    [1, 3, 5, -1, -1, -1],
                    [4, 1, 2, 4, 3, 1],
                    [3, 0, 4, -1, -1, -1],
                    [2, 0, 1, -1, -1, -1],
                    [-1, -1, -1, -1, -1, -1],
                ],
                dtype=torch.long,
            ),
            persistent=False,
        )
        # Number of triangles produced by each occupancy pattern.
        self.num_triangles_table: Integer[Tensor, "..."]
        self.register_buffer(
            "num_triangles_table",
            torch.as_tensor(
                [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
            ),
            persistent=False,
        )
        # Local vertex-index pairs of the 6 edges of a tetrahedron, flattened.
        self.base_tet_edges: Integer[Tensor, "..."]
        self.register_buffer(
            "base_tet_edges",
            torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
            persistent=False,
        )

        # Read both arrays inside a context manager so the archive's file
        # handle is released instead of staying open for the module lifetime.
        with np.load(self.tets_path) as tets:
            grid_vertices = torch.from_numpy(tets["vertices"]).float()
            tet_indices = torch.from_numpy(tets["indices"]).long()

        self._grid_vertices: Float[Tensor, "..."]
        self.register_buffer("_grid_vertices", grid_vertices, persistent=False)
        self.indices: Integer[Tensor, "..."]
        self.register_buffer("indices", tet_indices, persistent=False)

        # Lazily computed by the `all_edges` property.
        self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None

        center_indices, boundary_indices = self.get_center_boundary_index(
            self._grid_vertices
        )
        self.center_indices: Integer[Tensor, "..."]
        self.register_buffer("center_indices", center_indices, persistent=False)
        self.boundary_indices: Integer[Tensor, "..."]
        self.register_buffer("boundary_indices", boundary_indices, persistent=False)

    def get_center_boundary_index(self, verts):
        """Return the index of the vertex nearest the origin and the boundary set.

        A vertex counts as boundary when any of its coordinates equals the
        global maximum or global minimum coordinate value of ``verts``.
        """
        magn = torch.sum(verts**2, dim=-1)

        center_idx = torch.argmin(magn)
        at_max = verts == verts.max()
        at_min = verts == verts.min()

        boundary = torch.bitwise_or(at_min, at_max)
        boundary = torch.sum(boundary.float(), dim=-1)

        boundary_idx = torch.nonzero(boundary)
        return center_idx, boundary_idx.squeeze(dim=-1)

    def normalize_grid_deformation(
        self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
    ) -> Float[Tensor, "Nv 3"]:
        """Squash raw offsets with tanh and scale them to the grid cell size."""
        return (
            (self.points_range[1] - self.points_range[0])
            / self.resolution  # half tet size is approximately 1 / self.resolution
            * torch.tanh(grid_vertex_offsets)
        )  # FIXME: hard-coded activation

    @property
    def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
        """Undeformed grid vertex positions loaded from the archive."""
        return self._grid_vertices

    @property
    def all_edges(self) -> Integer[Tensor, "Ne 2"]:
        """Unique edges of the grid as sorted vertex-index pairs (cached)."""
        if self._all_edges is None:
            # compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation)
            # `base_tet_edges` is a buffer, so it already lives on the same
            # device as `indices`.
            _all_edges = self.indices[:, self.base_tet_edges].reshape(-1, 2)
            _all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
            _all_edges = torch.unique(_all_edges_sorted, dim=0)
            self._all_edges = _all_edges
        return self._all_edges

    def sort_edges(self, edges_ex2):
        """Order each edge's two vertex indices ascending (no gradient needed)."""
        with torch.no_grad():
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)

            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)

        return torch.stack([a, b], -1)

    def _forward(self, pos_nx3, sdf_n, tet_fx4):
        """Run marching tetrahedra.

        Args:
            pos_nx3: grid vertex positions.
            sdf_n: level value per grid vertex; values > 0 count as occupied.
            tet_fx4: vertex indices of every tetrahedron.

        Returns:
            ``(verts, faces)``: surface vertex positions, interpolated on the
            grid edges whose endpoints differ in occupancy, and triangle
            indices into ``verts``.
        """
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            # Only tetrahedra with mixed occupancy intersect the surface.
            valid_tets = (occ_sum > 0) & (occ_sum < 4)
            occ_sum = occ_sum[valid_tets]

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

            unique_edges = unique_edges.long()
            # Keep edges with exactly one occupied endpoint.
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            mapping = (
                torch.ones(
                    (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
                )
                * -1
            )
            mapping[mask_edges] = torch.arange(
                mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
            )
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]
        # Interpolation runs outside no_grad so gradients reach pos and sdf.
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        # Negate the second endpoint so the sum below equals sdf_0 - sdf_1.
        edges_to_interp_sdf[:, -1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        # Linear-interpolation weights of the zero crossing along each edge.
        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1, 6)

        # Encode the occupancy of the 4 tet vertices as a 4-bit table index.
        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        faces = torch.cat(
            (
                torch.gather(
                    input=idx_map[num_triangles == 1],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
                ).reshape(-1, 3),
                torch.gather(
                    input=idx_map[num_triangles == 2],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
                ).reshape(-1, 3),
            ),
            dim=0,
        )

        return verts, faces

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Extract the surface of ``level`` on the (optionally deformed) grid.

        Args:
            level: level value per grid vertex.
            deformation: optional raw per-vertex offsets; they are passed
                through `normalize_grid_deformation` before being added.

        Returns:
            A ``Mesh`` carrying the grid vertices, grid edges, level values and
            raw deformation as extras.
        """
        if deformation is not None:
            grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
                deformation
            )
        else:
            grid_vertices = self.grid_vertices

        v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)

        mesh = Mesh(
            v_pos=v_pos,
            t_pos_idx=t_pos_idx,
            # extras
            grid_vertices=grid_vertices,
            tet_edges=self.all_edges,
            grid_level=level,
            grid_deformation=deformation,
        )

        return mesh
sf3d/models/mesh.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import Any, Dict, Optional
5
+
6
+ import gpytoolbox
7
+ import numpy as np
8
+ import pynanoinstantmeshes
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import trimesh
12
+ from jaxtyping import Float, Integer
13
+ from torch import Tensor
14
+
15
+ from sf3d.models.utils import dot
16
+
17
+ try:
18
+ from uv_unwrapper import Unwrapper
19
+ except ImportError:
20
+ import logging
21
+
22
+ logging.warning(
23
+ "Could not import uv_unwrapper. Please install it via `pip install uv_unwrapper/`"
24
+ )
25
+ # Exit early to avoid further errors
26
+ raise ImportError("uv_unwrapper not found")
27
+
28
+
29
class Mesh:
    """Triangle mesh with lazily derived normals, tangents, UVs and edges.

    ``v_pos`` holds vertex positions and ``t_pos_idx`` the vertex indices of
    each triangle. Derived quantities are computed on first access and cached.
    Accessing ``v_tex`` (directly or through ``v_tng``) runs UV unwrapping,
    which REPLACES ``v_pos`` and ``t_pos_idx`` by a version with one vertex
    per triangle corner.
    """

    def __init__(
        self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
    ) -> None:
        """Store the geometry; every keyword argument is recorded in ``extras``."""
        self.v_pos: Float[Tensor, "Nv 3"] = v_pos
        self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
        self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
        self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
        self._v_tex: Optional[Float[Tensor, "Nt 2"]] = None
        self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
        self.extras: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.add_extra(k, v)

        self.unwrapper = Unwrapper()

    def add_extra(self, k, v) -> None:
        """Attach an arbitrary named value to the mesh."""
        self.extras[k] = v

    @property
    def requires_grad(self):
        """Whether the vertex positions require gradients."""
        return self.v_pos.requires_grad

    @property
    def v_nrm(self):
        """Per-vertex unit normals (cached)."""
        if self._v_nrm is None:
            self._v_nrm = self._compute_vertex_normal()
        return self._v_nrm

    @property
    def v_tng(self):
        """Per-vertex unit tangents (cached); may trigger UV unwrapping."""
        if self._v_tng is None:
            self._v_tng = self._compute_vertex_tangent()
        return self._v_tng

    @property
    def v_tex(self):
        """Per-vertex UV coordinates; runs `unwrap_uv` on first access."""
        if self._v_tex is None:
            self.unwrap_uv()
        return self._v_tex

    @property
    def edges(self):
        """Unique edges as sorted vertex-index pairs (cached)."""
        if self._edges is None:
            self._edges = self._compute_edges()
        return self._edges

    def _compute_vertex_normal(self):
        """Accumulate unnormalised face normals onto vertices and normalise."""
        i0 = self.t_pos_idx[:, 0]
        i1 = self.t_pos_idx[:, 1]
        i2 = self.t_pos_idx[:, 2]

        v0 = self.v_pos[i0, :]
        v1 = self.v_pos[i1, :]
        v2 = self.v_pos[i2, :]

        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

        # Splat face normals to vertices
        v_nrm = torch.zeros_like(self.v_pos)
        v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

        # Normalize, replace zero (degenerated) normals with some default value
        v_nrm = torch.where(
            dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
        )
        v_nrm = F.normalize(v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(v_nrm))

        return v_nrm

    def _compute_vertex_tangent(self):
        """Compute per-vertex tangents from positions, UVs and normals."""
        # Resolve the UVs first: the first access may run `unwrap_uv`, which
        # rewrites `v_pos` / `t_pos_idx`, so nothing may be gathered before it.
        v_tex = self.v_tex

        vn_idx = [None] * 3
        pos = [None] * 3
        tex = [None] * 3
        for i in range(0, 3):
            pos[i] = self.v_pos[self.t_pos_idx[:, i]]
            tex[i] = v_tex[self.t_pos_idx[:, i]]
            # t_nrm_idx is always the same as t_pos_idx
            vn_idx[i] = self.t_pos_idx[:, i]

        tangents = torch.zeros_like(self.v_nrm)
        tansum = torch.zeros_like(self.v_nrm)

        # Compute tangent space for each triangle
        duv1 = tex[1] - tex[0]
        duv2 = tex[2] - tex[0]
        dpos1 = pos[1] - pos[0]
        dpos2 = pos[2] - pos[0]

        tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]

        denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]

        # Avoid division by zero for degenerated texture coordinates while
        # keeping the sign: a negative determinant is a valid (mirrored) UV
        # triangle and must not be turned into a tiny positive number.
        denom_safe = torch.where(
            denom >= 0.0, denom.clamp(min=1e-6), denom.clamp(max=-1e-6)
        )
        tang = tng_nom / denom_safe

        # Update all 3 vertices
        for i in range(0, 3):
            idx = vn_idx[i][:, None].repeat(1, 3)
            tangents.scatter_add_(0, idx, tang)  # tangents[n_i] = tangents[n_i] + tang
            tansum.scatter_add_(
                0, idx, torch.ones_like(tang)
            )  # tansum[n_i] = tansum[n_i] + 1
        # Also normalize it. Here we do not normalize the individual triangles first so larger area
        # triangles influence the tangent space more
        tangents = tangents / tansum

        # Normalize and make sure tangent is perpendicular to normal
        tangents = F.normalize(tangents, dim=1)
        tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(tangents))

        return tangents

    def quad_remesh(
        self,
        quad_vertex_count: int = -1,
        quad_rosy: int = 4,
        quad_crease_angle: float = -1.0,
        quad_smooth_iter: int = 2,
        quad_align_to_boundaries: bool = False,
    ) -> Mesh:
        """Remesh with ``pynanoinstantmeshes`` and return the result as a new Mesh.

        A negative ``quad_vertex_count`` means "use the current vertex count".
        The remaining arguments are forwarded to ``pynanoinstantmeshes.remesh``.
        """
        if quad_vertex_count < 0:
            quad_vertex_count = self.v_pos.shape[0]
        v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
        t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.uint32)

        new_vert, new_faces = pynanoinstantmeshes.remesh(
            v_pos,
            t_pos_idx,
            quad_vertex_count // 4,
            rosy=quad_rosy,
            posy=4,
            creaseAngle=quad_crease_angle,
            align_to_boundaries=quad_align_to_boundaries,
            smooth_iter=quad_smooth_iter,
            deterministic=False,
        )

        # Briefly load in trimesh
        mesh = trimesh.Trimesh(vertices=new_vert, faces=new_faces.astype(np.int32))

        v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos).contiguous()
        t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx).contiguous()

        # Create new mesh
        return Mesh(v_pos, t_pos_idx)

    def triangle_remesh(
        self,
        triangle_average_edge_length_multiplier: Optional[float] = None,
        triangle_remesh_steps: int = 10,
        triangle_vertex_count=-1,
    ):
        """Remesh with ``gpytoolbox.remesh_botsch`` and return a new Mesh.

        When ``triangle_vertex_count`` is positive, the mesh is first
        subdivided (if it has too few vertices) and decimated towards that
        count; in that case THIS mesh's geometry is overwritten with the
        decimated result and the edge-length multiplier is ignored.
        Otherwise the target edge length is the current mean edge length times
        ``triangle_average_edge_length_multiplier`` (``None`` leaves the choice
        to ``remesh_botsch``).
        """
        if triangle_vertex_count > 0:
            reduction = triangle_vertex_count / self.v_pos.shape[0]
            print("Triangle reduction:", reduction)
            v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
            t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)
            if reduction > 1.0:
                subdivide_iters = int(math.ceil(math.log(reduction) / math.log(2)))
                print("Subdivide iters:", subdivide_iters)
                v_pos, t_pos_idx = gpytoolbox.subdivide(
                    v_pos,
                    t_pos_idx,
                    iters=subdivide_iters,
                )
                reduction = triangle_vertex_count / v_pos.shape[0]

            # Simplify
            points_out, faces_out, _, _ = gpytoolbox.decimate(
                v_pos,
                t_pos_idx,
                face_ratio=reduction,
            )

            # Convert back to torch
            self.v_pos = torch.from_numpy(points_out).to(self.v_pos)
            self.t_pos_idx = torch.from_numpy(faces_out).to(self.t_pos_idx)
            # The geometry changed, so every cached derived quantity is stale.
            self._edges = None
            self._v_nrm = None
            self._v_tng = None
            self._v_tex = None
            triangle_average_edge_length_multiplier = None

        edges = self.edges
        if triangle_average_edge_length_multiplier is None:
            h = None
        else:
            h = float(
                torch.linalg.norm(
                    self.v_pos[edges[:, 0]] - self.v_pos[edges[:, 1]], dim=1
                )
                .mean()
                .item()
                * triangle_average_edge_length_multiplier
            )

        # Convert to numpy
        v_pos = self.v_pos.detach().cpu().numpy().astype(np.float64)
        t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)

        # Remesh
        v_remesh, f_remesh = gpytoolbox.remesh_botsch(
            v_pos,
            t_pos_idx,
            triangle_remesh_steps,
            h,
        )

        # Convert back to torch
        v_pos = torch.from_numpy(v_remesh).to(self.v_pos).contiguous()
        t_pos_idx = torch.from_numpy(f_remesh).to(self.t_pos_idx).contiguous()

        # Create new mesh
        return Mesh(v_pos, t_pos_idx)

    @torch.no_grad()
    def unwrap_uv(
        self,
        island_padding: float = 0.02,
    ) -> None:
        """Unwrap the mesh in place and refresh normals and tangents.

        After the call every triangle owns three private vertices, so
        ``v_pos`` has ``3 * Nf`` rows and ``t_pos_idx`` is simply
        ``arange(3 * Nf).reshape(-1, 3)``. Nothing is returned.
        """
        uv, indices = self.unwrapper(
            self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
        )

        # Do store per vertex UVs.
        # This means we need to duplicate some vertices at the seams
        individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
        individual_faces = torch.arange(
            individual_vertices.shape[0],
            device=individual_vertices.device,
            dtype=self.t_pos_idx.dtype,
        ).reshape(-1, 3)
        uv_flat = uv[indices].reshape((-1, 2))
        # uv_flat[:, 1] = 1 - uv_flat[:, 1]

        self.v_pos = individual_vertices
        self.t_pos_idx = individual_faces
        self._v_tex = uv_flat
        self._v_nrm = self._compute_vertex_normal()
        self._v_tng = self._compute_vertex_tangent()

    def _compute_edges(self):
        """Return the unique undirected edges of all triangles."""
        # Compute edges
        edges = torch.cat(
            [
                self.t_pos_idx[:, [0, 1]],
                self.t_pos_idx[:, [1, 2]],
                self.t_pos_idx[:, [2, 0]],
            ],
            dim=0,
        )
        # Sort each pair so (a, b) and (b, a) collapse in `unique`.
        edges = edges.sort()[0]
        edges = torch.unique(edges, dim=0)
        return edges
sf3d/models/network.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Callable, List, Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from jaxtyping import Float
9
+ from torch import Tensor
10
+ from torch.amp import custom_bwd, custom_fwd
11
+ from torch.autograd import Function
12
+
13
+ from sf3d.models.utils import BaseModule, normalize
14
+ from sf3d.utils import get_device
15
+
16
+
17
def conditional_decorator(decorator_with_args, condition, *args, **kwargs):
    """Build a decorator that applies ``decorator_with_args`` only if ``condition``.

    When ``condition`` is falsy the decorated function is returned untouched.
    When it is truthy and keyword arguments were supplied, the function is
    wrapped with ``decorator_with_args(*args, **kwargs)``.
    """

    def wrapper(fn):
        if not condition:
            return fn
        if not kwargs:
            # NOTE(review): this hands back the decorator object itself rather
            # than a decorated `fn` - looks unintended, but callers in this
            # file currently rely on it not being invoked at definition time.
            return decorator_with_args
        return decorator_with_args(*args, **kwargs)(fn)

    return wrapper
27
+
28
+
29
class PixelShuffleUpsampleNetwork(BaseModule):
    """Convolution stack followed by a pixel shuffle, applied to each plane.

    All but the last convolution keep ``in_channels`` channels and are followed
    by a ReLU; the last one emits ``out_channels * scale_factor**2`` channels,
    which ``nn.PixelShuffle`` rearranges into ``out_channels`` channels.
    """

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 1024
        out_channels: int = 40
        scale_factor: int = 4

        conv_layers: int = 4
        conv_kernel_size: int = 3

    cfg: Config

    def configure(self) -> None:
        """Assemble ``self.upsample`` from the configuration."""
        cfg = self.cfg
        kernel = cfg.conv_kernel_size
        last_index = cfg.conv_layers - 1
        shuffled_channels = cfg.out_channels * cfg.scale_factor**2

        modules = []
        for index in range(cfg.conv_layers):
            is_last = index == last_index
            modules.append(
                nn.Conv2d(
                    cfg.in_channels,
                    shuffled_channels if is_last else cfg.in_channels,
                    kernel,
                    padding=(kernel - 1) // 2,
                )
            )
            if not is_last:
                modules.append(nn.ReLU(inplace=True))
        modules.append(nn.PixelShuffle(cfg.scale_factor))

        self.upsample = nn.Sequential(*modules)

    def forward(
        self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
    ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
        """Upsample the three planes by folding them into the batch axis."""
        stacked = rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
        upsampled = self.upsample(stacked)
        return rearrange(upsampled, "(B Np) Co Hp Wp -> B Np Co Hp Wp", Np=3)
75
+
76
+
77
class _TruncExp(Function):  # pylint: disable=abstract-method
    """``exp`` whose backward pass clamps the input at 15 to keep gradients finite.

    On CUDA both passes are wrapped with the ``torch.amp`` custom autograd
    decorators; on other devices they are left undecorated.
    """

    # Implementation from torch-ngp:
    # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
    @staticmethod
    @conditional_decorator(
        custom_fwd,
        "cuda" in get_device(),
        cast_inputs=torch.float32,
        device_type="cuda",
    )
    def forward(ctx, x):  # pylint: disable=arguments-differ
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    # `device_type` must be passed explicitly: without keyword arguments
    # `conditional_decorator` returns `custom_bwd` itself instead of a wrapped
    # `backward`, which makes any backward pass on CUDA fail.
    @conditional_decorator(
        custom_bwd,
        "cuda" in get_device(),
        device_type="cuda",
    )
    def backward(ctx, g):  # pylint: disable=arguments-differ
        x = ctx.saved_tensors[0]
        return g * torch.exp(torch.clamp(x, max=15))


trunc_exp = _TruncExp.apply
99
+
100
+
101
def get_activation(name) -> Callable:
    """Resolve an activation name to a callable.

    ``None``, ``"none"``, ``"linear"`` and ``"identity"`` give the identity.
    Names are matched case-insensitively. Names not listed here are looked up
    as attributes of ``torch.nn.functional``.

    Raises:
        ValueError: if the name is neither built in nor found in
            ``torch.nn.functional``.
    """
    if name is None:
        return lambda x: x
    name = name.lower()
    if name in ("none", "linear", "identity"):
        return lambda x: x
    # Returned as-is (not wrapped) so the caller gets the autograd function.
    if name == "trunc_exp":
        return trunc_exp

    builtin = {
        "lin2srgb": lambda x: torch.where(
            x > 0.0031308,
            torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
            12.92 * x,
        ).clamp(0.0, 1.0),
        "exp": lambda x: torch.exp(x),
        "shifted_exp": lambda x: torch.exp(x - 1.0),
        "shifted_trunc_exp": lambda x: trunc_exp(x - 1.0),
        "sigmoid": lambda x: torch.sigmoid(x),
        "tanh": lambda x: torch.tanh(x),
        "shifted_softplus": lambda x: F.softplus(x - 1.0),
        "scale_-11_01": lambda x: x * 0.5 + 0.5,
        "negative": lambda x: -x,
        "normalize_channel_last": lambda x: normalize(x),
        "normalize_channel_first": lambda x: normalize(x, dim=1),
    }
    if name in builtin:
        return builtin[name]

    try:
        return getattr(F, name)
    except AttributeError:
        raise ValueError(f"Unknown activation function: {name}")
140
+
141
+
142
@dataclass
class HeadSpec:
    """Configuration record describing one output head of an MLP."""

    # Key under which the head's module is stored and its output is returned.
    name: str
    # Output width of the head's final linear layer.
    out_channels: int
    # Number of Linear + activation pairs placed before the final linear layer.
    n_hidden_layers: int
    # Name passed to `get_activation`; None means identity.
    output_activation: Optional[str] = None
    # Constant added to the head output before the output activation.
    out_bias: float = 0.0
149
+
150
+
151
class MaterialMLP(BaseModule):
    """Collection of independent MLP heads evaluated on the same input features.

    Each head is a stack of ``n_hidden_layers`` Linear + activation pairs
    followed by a final Linear producing ``out_channels`` values.
    """

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 120
        n_neurons: int = 64
        activation: str = "silu"
        heads: List[HeadSpec] = field(default_factory=list)

    cfg: Config

    def configure(self) -> None:
        """Build one ``nn.Sequential`` per configured head."""
        assert len(self.cfg.heads) > 0
        heads = {}
        for head in self.cfg.heads:
            head_layers = []
            # Track the running feature width so that a head without hidden
            # layers maps straight from `in_channels` to its outputs.
            in_features = self.cfg.in_channels
            for _ in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(in_features, self.cfg.n_neurons),
                    self.make_activation(self.cfg.activation),
                ]
                in_features = self.cfg.n_neurons
            head_layers += [
                nn.Linear(
                    in_features,
                    head.out_channels,
                ),
            ]
            heads[head.name] = nn.Sequential(*head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        """Return an in-place ``nn.ReLU`` or ``nn.SiLU`` for ``activation``.

        Raises:
            NotImplementedError: for any name other than "relu" / "silu".
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def keys(self):
        """Return the names of all heads."""
        return self.heads.keys()

    def forward(
        self, x, include: Optional[List] = None, exclude: Optional[List] = None
    ):
        """Evaluate the selected heads on ``x``.

        Args:
            x: input features passed unchanged to every selected head.
            include: if given, only heads whose name is listed are evaluated.
            exclude: if given, heads whose name is listed are skipped.

        Returns:
            Dict mapping head name to ``activation(head(x) + out_bias)``.

        Raises:
            ValueError: if both ``include`` and ``exclude`` are given.
        """
        if include is not None and exclude is not None:
            raise ValueError("Cannot specify both include and exclude.")
        if include is not None:
            heads = [h for h in self.cfg.heads if h.name in include]
        elif exclude is not None:
            heads = [h for h in self.cfg.heads if h.name not in exclude]
        else:
            heads = self.cfg.heads

        out = {
            head.name: get_activation(head.output_activation)(
                self.heads[head.name](x) + head.out_bias
            )
            for head in heads
        }

        return out
sf3d/models/tokenizers/dinov2.py ADDED
@@ -0,0 +1,1196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DINOv2 model."""
16
+
17
+ import collections.abc
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Dict, List, Optional, Set, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+ from transformers.activations import ACT2FN
28
+ from transformers.modeling_outputs import (
29
+ BackboneOutput,
30
+ BaseModelOutput,
31
+ BaseModelOutputWithPooling,
32
+ ImageClassifierOutput,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
36
+ from transformers.pytorch_utils import (
37
+ find_pruneable_heads_and_indices,
38
+ prune_linear_layer,
39
+ )
40
+ from transformers.utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from transformers.utils.backbone_utils import BackboneMixin
48
+
49
# Module-level logger, obtained through the logging helper imported above.
logger = logging.get_logger(__name__)

# Config class name quoted by the code-sample docstring decorators below.
_CONFIG_FOR_DOC = "Dinov2Config"

# Checkpoint and expected output shape quoted in the base model's code sample.
_CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]

# Checkpoint quoted in the image-classification model's code sample.
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"


DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/dinov2-base",
    # See all DINOv2 models at https://huggingface.co/models?filter=dinov2
]
66
+
67
+
68
class Dinov2Embeddings(nn.Module):
    """
    Build the token sequence handed to the encoder: patch embeddings with a prepended CLS token, an optional
    mask-token substitution, and position embeddings (interpolated when the input size requires it).
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()

        hidden = config.hidden_size
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden))
        # The mask token is stored as a buffer rather than a parameter: it is
        # not optimized, and keeping it out of the parameter list avoids the
        # need for the "find unused parameters" option during training.
        self.register_buffer("mask_token", torch.zeros(1, hidden))
        self.patch_embeddings = Dinov2PatchEmbeddings(config)
        # One position per patch plus one for the CLS token.
        self.position_embeddings = nn.Parameter(
            torch.randn(1, self.patch_embeddings.num_patches + 1, hidden)
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """
        Resize the stored position embeddings so they match the patch grid of an input of size
        `height` x `width` (in pixels), which lets the model run on resolutions other than the stored one.

        The stored embeddings are returned untouched when the token count already matches and the input is
        square. Otherwise the patch part is bicubically interpolated and the CLS position is re-attached.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Raises:
            ValueError: if the interpolated grid does not have the expected size.
        """
        stored_patches = self.position_embeddings.shape[1] - 1
        if embeddings.shape[1] - 1 == stored_patches and height == width:
            return self.position_embeddings

        cls_pos = self.position_embeddings[:, 0]
        grid_pos = self.position_embeddings[:, 1:]
        channels = embeddings.shape[-1]

        # A small offset is added to the target grid size to avoid floating point
        # error in the interpolation,
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        target_h = height // self.config.patch_size + 0.1
        target_w = width // self.config.patch_size + 0.1

        # The stored patch positions are laid out as a square grid.
        side = math.sqrt(stored_patches)
        grid_pos = grid_pos.reshape(1, int(side), int(side), channels)
        grid_pos = nn.functional.interpolate(
            grid_pos.permute(0, 3, 1, 2),
            scale_factor=(target_h / side, target_w / side),
            mode="bicubic",
            align_corners=False,
        )
        if (int(target_h), int(target_w)) != tuple(grid_pos.shape[-2:]):
            raise ValueError(
                "Width or height does not match with the interpolated position embeddings"
            )
        grid_pos = grid_pos.permute(0, 2, 3, 1).view(1, -1, channels)
        return torch.cat((cls_pos.unsqueeze(0), grid_pos), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Embed `pixel_values` (unpacked as batch, channels, height, width) into a token sequence.

        When `bool_masked_pos` is given, patch tokens at the selected positions are replaced by the mask token
        before the CLS token is prepended.
        """
        batch_size, _, height, width = pixel_values.shape
        tokens = self.patch_embeddings(pixel_values)

        if bool_masked_pos is not None:
            mask_token = self.mask_token.to(tokens.dtype).unsqueeze(0)
            tokens = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, tokens)

        # Prepend the [CLS] token, then add the positional encoding to every token.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = tokens + self.interpolate_pos_encoding(tokens, height, width)

        return self.dropout(tokens)
163
+
164
+
165
class Dinov2PatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        # Scalars are promoted to (height, width) pairs; iterables are kept as given.
        image_size = (
            image_size
            if isinstance(image_size, collections.abc.Iterable)
            else (image_size, image_size)
        )
        patch_size = (
            patch_size
            if isinstance(patch_size, collections.abc.Iterable)
            else (patch_size, patch_size)
        )
        num_patches = (image_size[1] // patch_size[1]) * (
            image_size[0] // patch_size[0]
        )
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        # Non-overlapping patches: kernel and stride both equal the patch size.
        self.projection = nn.Conv2d(
            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Project `pixel_values` into a sequence of patch embeddings.

        Args:
            pixel_values: image batch laid out as `(batch_size, num_channels, height, width)`.

        Returns:
            Tensor of shape `(batch_size, seq_length, hidden_size)`.

        Raises:
            ValueError: if a 4D input has a channel count different from the one the projection expects.
        """
        # The check uses the live projection rather than `self.num_channels`:
        # the projection can be swapped for one with more input channels
        # (see `Dinov2Model.expand_input_channels`), which would leave
        # `self.num_channels` stale.
        expected_channels = self.projection.in_channels
        if pixel_values.ndim == 4 and pixel_values.shape[1] != expected_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the input channels of the"
                f" patch projection. Expected {expected_channels} but got {pixel_values.shape[1]}."
            )
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings
210
+
211
+
212
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
213
class Dinov2SelfAttention(nn.Module):
    """
    Multi-head self-attention.

    Uses `torch.nn.functional.scaled_dot_product_attention` when it is available and neither a head mask nor
    the attention probabilities are requested; otherwise the attention is computed explicitly.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
            config, "embedding_size"
        ):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

        self.query = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.qkv_bias
        )
        self.key = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.qkv_bias
        )
        self.value = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.qkv_bias
        )

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dimension into (heads, head_size) and move the head axis in front of the sequence axis."""
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """
        Apply self-attention to `hidden_states`.

        Args:
            hidden_states: input token sequence.
            head_mask: optional multiplicative mask applied to the attention probabilities.
            output_attentions: when True, the attention probabilities are returned as a second tuple element.

        Returns:
            `(context,)` or `(context, attention_probs)`.
        """
        mixed_query_layer = self.query(hidden_states)

        # The fused kernel neither exposes the attention probabilities nor takes
        # a multiplicative head mask, so it is only used when neither is needed.
        use_fused = (
            hasattr(F, "scaled_dot_product_attention")
            and head_mask is None
            and not output_attentions
        )
        if use_fused:
            new_size = hidden_states.size()[:-1] + (
                self.num_attention_heads,
                self.attention_head_size,
            )
            key_layer = self.key(hidden_states).reshape(new_size).transpose(1, 2)
            value_layer = self.value(hidden_states).reshape(new_size).transpose(1, 2)
            query_layer = mixed_query_layer.reshape(new_size).transpose(1, 2)
            context_layer = F.scaled_dot_product_attention(
                query_layer,
                key_layer,
                value_layer,
                # The functional API applies dropout whenever dropout_p > 0,
                # so it has to be switched off explicitly outside of training.
                dropout_p=self.attention_probs_dropout_prob if self.training else 0.0,
                is_causal=False,
            )
            context_layer = context_layer.transpose(1, 2).reshape(
                *hidden_states.size()[:-1], -1
            )
            return (context_layer,)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )
308
+
309
+
310
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
311
class Dinov2SelfOutput(nn.Module):
    """
    Output projection that follows self-attention: a linear layer and dropout.

    No residual connection is added here (unlike in some other models); Dinov2Layer adds it, because the
    layernorm is applied before each block.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        # `input_tensor` is part of the signature but is not used by this block.
        return self.dropout(self.dense(hidden_states))
329
+
330
+
331
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
332
class Dinov2Attention(nn.Module):
    """Self-attention followed by its output projection, with support for pruning attention heads."""

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.attention = Dinov2SelfAttention(config)
        self.output = Dinov2SelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        """Remove the given attention heads from the query/key/value and output projections."""
        if len(heads) == 0:
            return

        self_attn = self.attention
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self_attn.num_attention_heads,
            self_attn.attention_head_size,
            self.pruned_heads,
        )

        # Shrink the three input projections, and the output projection along its input dimension.
        self_attn.query = prune_linear_layer(self_attn.query, index)
        self_attn.key = prune_linear_layer(self_attn.key, index)
        self_attn.value = prune_linear_layer(self_attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the bookkeeping in sync with the pruned layers.
        self_attn.num_attention_heads = self_attn.num_attention_heads - len(heads)
        self_attn.all_head_size = (
            self_attn.attention_head_size * self_attn.num_attention_heads
        )
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Return `(attention_output,)`, followed by whatever extra entries the self-attention produced."""
        attn_results = self.attention(hidden_states, head_mask, output_attentions)
        projected = self.output(attn_results[0], hidden_states)
        # Trailing entries (the attention probabilities, if any) are passed through.
        return (projected,) + attn_results[1:]
378
+
379
+
380
class Dinov2LayerScale(nn.Module):
    """Multiply the input by a learnable vector of size `hidden_size`, initialized to `config.layerscale_value`."""

    def __init__(self, config) -> None:
        super().__init__()
        initial_scale = config.layerscale_value * torch.ones(config.hidden_size)
        self.lambda1 = nn.Parameter(initial_scale)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        scaled = hidden_state * self.lambda1
        return scaled
389
+
390
+
391
+ # Copied from transformers.models.beit.modeling_beit.drop_path
392
def drop_path(
    input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Each sample along the first dimension is either zeroed or rescaled by `1 / (1 - drop_prob)`. The input is
    returned unchanged when `training` is False or `drop_prob` is 0.

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if not training or drop_prob == 0.0:
        return input

    survival = 1 - drop_prob
    # One random draw per sample, broadcastable over all remaining dimensions
    # (works with tensors of any rank, not just 2D ConvNets).
    mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    keep_mask = torch.rand(mask_shape, dtype=input.dtype, device=input.device)
    # Adding the survival probability and flooring binarizes the draw to 0 or 1.
    keep_mask = (keep_mask + survival).floor_()
    return input.div(survival) * keep_mask
416
+
417
+
418
+ # Copied from transformers.models.beit.modeling_beit.BeitDropPath
419
class Dinov2DropPath(nn.Module):
    """Module form of `drop_path`: drop paths (Stochastic Depth) per sample, active only in training mode."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The module's own training flag decides whether anything is dropped.
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
431
+
432
+
433
class Dinov2MLP(nn.Module):
    """Two-layer feed-forward block: expand by `config.mlp_ratio`, apply the activation, project back."""

    def __init__(self, config) -> None:
        super().__init__()
        width = config.hidden_size
        expanded = int(width * config.mlp_ratio)
        self.fc1 = nn.Linear(width, expanded, bias=True)
        # `hidden_act` is either the name of an activation in ACT2FN or the callable itself.
        act = config.hidden_act
        self.activation = ACT2FN[act] if isinstance(act, str) else act
        self.fc2 = nn.Linear(expanded, width, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation(self.fc1(hidden_state)))
450
+
451
+
452
class Dinov2SwiGLUFFN(nn.Module):
    """Gated feed-forward block: one linear layer produces both halves of `silu(x1) * x2`, a second projects back."""

    def __init__(self, config) -> None:
        super().__init__()
        width = config.hidden_size
        # Two thirds of the nominal MLP width, rounded up to a multiple of 8.
        gated_width = int(int(width * config.mlp_ratio) * 2 / 3)
        gated_width = (gated_width + 7) // 8 * 8

        # `weights_in` emits both the gate and the value in a single projection.
        self.weights_in = nn.Linear(width, 2 * gated_width, bias=True)
        self.weights_out = nn.Linear(gated_width, width, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        gate, value = self.weights_in(hidden_state).chunk(2, dim=-1)
        return self.weights_out(nn.functional.silu(gate) * value)
467
+
468
+
469
+ class Dinov2Layer(nn.Module):
470
+ """This corresponds to the Block class in the original implementation."""
471
+
472
+ def __init__(self, config: Dinov2Config) -> None:
473
+ super().__init__()
474
+
475
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
476
+ self.norm1_modulation = None
477
+ self.attention = Dinov2Attention(config)
478
+ self.layer_scale1 = Dinov2LayerScale(config)
479
+ self.drop_path1 = (
480
+ Dinov2DropPath(config.drop_path_rate)
481
+ if config.drop_path_rate > 0.0
482
+ else nn.Identity()
483
+ )
484
+
485
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
486
+ self.norm2_modulation = None
487
+
488
+ if config.use_swiglu_ffn:
489
+ self.mlp = Dinov2SwiGLUFFN(config)
490
+ else:
491
+ self.mlp = Dinov2MLP(config)
492
+ self.layer_scale2 = Dinov2LayerScale(config)
493
+ self.drop_path2 = (
494
+ Dinov2DropPath(config.drop_path_rate)
495
+ if config.drop_path_rate > 0.0
496
+ else nn.Identity()
497
+ )
498
+
499
+ def forward(
500
+ self,
501
+ hidden_states: torch.Tensor,
502
+ head_mask: Optional[torch.Tensor] = None,
503
+ modulation_cond: Optional[torch.Tensor] = None,
504
+ output_attentions: bool = False,
505
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
506
+ hidden_states_norm = self.norm1(hidden_states)
507
+ if self.norm1_modulation is not None:
508
+ assert modulation_cond is not None
509
+ hidden_states_norm = self.norm1_modulation(
510
+ hidden_states_norm, modulation_cond
511
+ )
512
+ self_attention_outputs = self.attention(
513
+ hidden_states_norm, # in Dinov2, layernorm is applied before self-attention
514
+ head_mask,
515
+ output_attentions=output_attentions,
516
+ )
517
+ attention_output = self_attention_outputs[0]
518
+
519
+ attention_output = self.layer_scale1(attention_output)
520
+ outputs = self_attention_outputs[
521
+ 1:
522
+ ] # add self attentions if we output attention weights
523
+
524
+ # first residual connection
525
+ hidden_states = attention_output + hidden_states
526
+
527
+ # in Dinov2, layernorm is also applied after self-attention
528
+ layer_output = self.norm2(hidden_states)
529
+ if self.norm2_modulation is not None:
530
+ assert modulation_cond is not None
531
+ layer_output = self.norm2_modulation(layer_output, modulation_cond)
532
+ layer_output = self.mlp(layer_output)
533
+ layer_output = self.layer_scale2(layer_output)
534
+
535
+ # second residual connection
536
+ layer_output = layer_output + hidden_states
537
+
538
+ outputs = (layer_output,) + outputs
539
+
540
+ return outputs
541
+
542
+ def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
543
+ self.norm1_modulation = norm1_mod
544
+ self.norm2_modulation = norm2_mod
545
+
546
+
547
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
548
class Dinov2Encoder(nn.Module):
    """Stack of `config.num_hidden_layers` Dinov2Layer blocks applied in sequence."""

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.config = config
        blocks = [Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(blocks)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        modulation_cond: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """
        Run every layer in order.

        Hidden states (the input of each layer plus the final output) and per-layer attentions are collected
        only when the corresponding flag is set. With `return_dict=False` a tuple of the non-None results is
        returned instead of a `BaseModelOutput`.
        """

        def bind(block):
            # Each checkpointed call gets its own closure over the layer it wraps.
            def run(*inputs):
                return block(*inputs, output_attentions)

            return run

        collected_states = () if output_hidden_states else None
        collected_attentions = () if output_attentions else None
        use_checkpointing = self.gradient_checkpointing and self.training

        for index, block in enumerate(self.layer):
            if output_hidden_states:
                collected_states = collected_states + (hidden_states,)

            block_mask = None if head_mask is None else head_mask[index]

            if use_checkpointing:
                block_outputs = torch.utils.checkpoint.checkpoint(
                    bind(block),
                    hidden_states,
                    block_mask,
                    modulation_cond,
                    use_reentrant=False,
                )
            else:
                block_outputs = block(
                    hidden_states, block_mask, modulation_cond, output_attentions
                )

            hidden_states = block_outputs[0]

            if output_attentions:
                collected_attentions = collected_attentions + (block_outputs[1],)

        if output_hidden_states:
            collected_states = collected_states + (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=collected_states,
                attentions=collected_attentions,
            )
        results = (hidden_states, collected_states, collected_attentions)
        return tuple(item for item in results if item is not None)
614
+
615
+
616
class Dinov2PreTrainedModel(PreTrainedModel):
    """
    Abstract base that takes care of weight initialization and provides the simple interface for downloading
    and loading pretrained models.
    """

    config_class = Dinov2Config
    base_model_prefix = "dinov2"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""

        def truncated_normal(param):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            drawn = nn.init.trunc_normal_(
                param.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            )
            return drawn.to(param.dtype)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data = truncated_normal(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Dinov2Embeddings):
            module.position_embeddings.data = truncated_normal(
                module.position_embeddings
            )
            module.cls_token.data = truncated_normal(module.cls_token)

    def _set_gradient_checkpointing(
        self, module: Dinov2Encoder, value: bool = False
    ) -> None:
        # Only the encoder carries the gradient-checkpointing flag.
        if isinstance(module, Dinov2Encoder):
            module.gradient_checkpointing = value
660
+
661
+
662
# Shared docstring fragments, attached to the model classes and their forward
# methods by the `add_start_docstrings*` decorators.
DINOV2_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Inputs of `Dinov2Model.forward`.
DINOV2_BASE_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`BitImageProcessor.preprocess`] for details.

        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
            pre-training.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        modulation_cond (`torch.Tensor`, *optional*):
            Conditioning input forwarded unchanged to every encoder layer, where it is passed to the modulation
            modules registered with `Dinov2Layer.register_ada_norm_modulation`. Must be provided when such modules
            are registered; unused otherwise.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Inputs of the task-specific heads built on top of the base model.
DINOV2_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`BitImageProcessor.preprocess`] for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
720
+
721
+
722
@dataclass
class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling):
    """`BaseModelOutputWithPooling` extended with an optional `patch_embeddings` field."""

    # Filled by `Dinov2Model.forward` with the output of its embeddings module,
    # i.e. the token sequence that is fed to the encoder.
    patch_embeddings: Optional[torch.FloatTensor] = None
725
+
726
+
727
@add_start_docstrings(
    "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
    DINOV2_START_DOCSTRING,
)
class Dinov2Model(Dinov2PreTrainedModel):
    def __init__(self, config: Dinov2Config):
        super().__init__(config)
        self.config = config

        self.embeddings = Dinov2Embeddings(config)
        self.encoder = Dinov2Encoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
        return self.embeddings.patch_embeddings

    def expand_input_channels(self, extra_input_channels: int) -> None:
        """
        Replace the patch projection with one that accepts `extra_input_channels` additional input channels.

        The weights for the existing channels and the bias are taken from the current projection. The weights
        for the added channels keep `nn.Conv2d`'s default initialization. The channel count of the current
        projection (not `config.num_channels`) is used as the base, so the method also works for non-RGB
        configurations and when called more than once. `config.num_channels` is left unchanged.

        Raises:
            ValueError: if `extra_input_channels` is negative.
        """
        if extra_input_channels == 0:
            return
        if extra_input_channels < 0:
            raise ValueError(
                f"extra_input_channels must be non-negative, got {extra_input_channels}."
            )
        conv_old = self.embeddings.patch_embeddings.projection
        old_channels = conv_old.in_channels
        conv_new = nn.Conv2d(
            old_channels + extra_input_channels,
            self.config.hidden_size,
            kernel_size=self.config.patch_size,
            stride=self.config.patch_size,
            # Match the old projection so weight and (shared) bias agree in device and dtype.
        ).to(device=conv_old.weight.device, dtype=conv_old.weight.dtype)
        with torch.no_grad():
            conv_new.weight[:, :old_channels] = conv_old.weight
            conv_new.bias = conv_old.bias
        self.embeddings.patch_embeddings.projection = conv_new
        del conv_old

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        modulation_cond: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        # Flags left as None fall back to the values stored in the config.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos
        )

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            modulation_cond=modulation_cond,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        # The pooled output is the normalized token at position 0 (the CLS token).
        pooled_output = sequence_output[:, 0, :]

        if not return_dict:
            head_outputs = (sequence_output, pooled_output)
            return head_outputs + encoder_outputs[1:]

        # The embedding output is returned alongside the usual fields.
        return CustomBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            patch_embeddings=embedding_output,
        )

    def set_gradient_checkpointing(self, value: bool = False) -> None:
        """Switch gradient checkpointing on or off for the encoder."""
        self._set_gradient_checkpointing(self.encoder, value)
843
+
844
+
845
@add_start_docstrings(
    """
    Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.
    """,
    DINOV2_START_DOCSTRING,
)
class Dinov2ForImageClassification(Dinov2PreTrainedModel):
    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.dinov2 = Dinov2Model(config)

        # The head consumes the [CLS] token concatenated with the mean patch
        # token, hence the doubled input width. Without labels it is a no-op.
        if config.num_labels > 0:
            self.classifier = nn.Linear(config.hidden_size * 2, config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.dinov2(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # batch_size, sequence_length, hidden_size
        tokens = outputs[0]
        cls_token = tokens[:, 0]
        mean_patch_token = tokens[:, 1:].mean(dim=1)
        logits = self.classifier(torch.cat([cls_token, mean_patch_token], dim=1))

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)

            # Infer the problem type once from the label count / dtype and
            # cache it on the config for subsequent calls.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if return_dict:
            return ImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple output: optional loss, logits, then whatever follows the
        # first two entries of the backbone output.
        output = (logits,) + outputs[2:]
        return output if loss is None else (loss,) + output
948
+
949
+
950
@add_start_docstrings(
    """
    Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
    """,
    DINOV2_START_DOCSTRING,
)
class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        # One feature width per stage: the embedding output plus every layer.
        self.num_features = [
            config.hidden_size for _ in range(config.num_hidden_layers + 1)
        ]
        self.embeddings = Dinov2Embeddings(config)
        self.encoder = Dinov2Encoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
        """Return the patch-embedding sub-module of the embeddings layer."""
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
        >>> model = AutoBackbone.from_pretrained(
        ...     "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 16, 16]
        ```"""
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )

        embedding_output = self.embeddings(pixel_values)

        # Hidden states are always requested from the encoder because the
        # feature maps are built from them; the caller's flag only controls
        # whether they are also returned.
        outputs = self.encoder(
            embedding_output,
            output_hidden_states=True,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                if self.config.apply_layernorm:
                    hidden_state = self.layernorm(hidden_state)
                if self.config.reshape_hidden_states:
                    batch_size, _, height, width = pixel_values.shape
                    patch_size = self.config.patch_size
                    # Drop the leading (CLS) token and fold the patch tokens
                    # back into a 2D grid. The grid must be (rows, cols) =
                    # (height, width) so that the permute below yields
                    # (batch, channels, height, width); the previous
                    # (width, height) order scrambled non-square inputs.
                    # NOTE(review): assumes the embeddings flatten the patch
                    # grid row-major (height first) - confirm against
                    # Dinov2Embeddings.
                    hidden_state = hidden_state[:, 1:, :].reshape(
                        batch_size, height // patch_size, width // patch_size, -1
                    )
                    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        if not return_dict:
            if output_hidden_states:
                output = (feature_maps,) + outputs[1:]
            else:
                output = (feature_maps,) + outputs[2:]
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions if output_attentions else None,
        )
1061
+
1062
+
1063
class CustomPatchEmbeddings(nn.Module):
    """
    Convert `pixel_values` of shape `(batch_size, num_channels, height, width)` into patch embeddings of shape
    `(batch_size, seq_length, hidden_size)` by applying a strided convolution and flattening the spatial grid.
    """

    def __init__(
        self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
    ):
        super().__init__()

        # Scalars are promoted to (size, size) pairs; iterables are kept as-is.
        if not isinstance(image_size, collections.abc.Iterable):
            image_size = (image_size, image_size)
        if not isinstance(patch_size, collections.abc.Iterable):
            patch_size = (patch_size, patch_size)

        patches_dim0 = image_size[0] // patch_size[0]
        patches_dim1 = image_size[1] // patch_size[1]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = patches_dim1 * patches_dim0

        # Non-overlapping patches: the kernel and the stride are both the patch size.
        self.projection = nn.Conv2d(
            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Project the image into patches and return them as a token sequence.

        Raises:
            ValueError: if the channel dimension of `pixel_values` differs from `num_channels`.
        """
        num_channels = pixel_values.shape[1]
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        projected = self.projection(pixel_values)
        # (B, hidden, H', W') -> (B, hidden, H'*W') -> (B, H'*W', hidden)
        return projected.flatten(2).transpose(1, 2)
1106
+
1107
+
1108
class CustomEmbeddings(nn.Module):
    """
    Build the token sequence fed to the transformer: a learnable CLS token, the patch embeddings and learnable
    position embeddings (interpolated when the input resolution differs from the configured one).
    """

    def __init__(
        self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
    ) -> None:
        super().__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.hidden_size = hidden_size

        self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))

        self.patch_embeddings = CustomPatchEmbeddings(
            image_size, patch_size, num_channels, hidden_size
        )
        # One position per patch plus one for the CLS token.
        self.position_embeddings = nn.Parameter(
            torch.randn(1, self.patch_embeddings.num_patches + 1, self.hidden_size)
        )

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """
        Resize the learned position embeddings to the patch grid of a `height` x `width` input so the model can
        be applied at resolutions other than the one it was built for.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Raises:
            ValueError: if the interpolated grid does not match the expected patch grid.
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        # Fast path: same token count and a square input, nothing to resize.
        if num_patches == num_positions and height == width:
            return self.position_embeddings

        cls_pos = self.position_embeddings[:, 0]
        grid_pos = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        # The stored grid is treated as square with this side length.
        side = math.sqrt(num_positions)

        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        target_h = height // self.patch_size + 0.1
        target_w = width // self.patch_size + 0.1

        grid_pos = grid_pos.reshape(1, int(side), int(side), dim).permute(0, 3, 1, 2)
        grid_pos = nn.functional.interpolate(
            grid_pos,
            scale_factor=(target_h / side, target_w / side),
            mode="bicubic",
            align_corners=False,
        )
        if (int(target_h), int(target_w)) != tuple(grid_pos.shape[-2:]):
            raise ValueError(
                "Width or height does not match with the interpolated position embeddings"
            )
        grid_pos = grid_pos.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((cls_pos.unsqueeze(0), grid_pos), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Embed `pixel_values`, prepend the CLS token and add position embeddings."""
        batch_size, _, height, width = pixel_values.shape

        # add the [CLS] token to the embedded patch tokens
        patch_tokens = self.patch_embeddings(pixel_values)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, patch_tokens), dim=1)

        # add positional encoding to each token
        return tokens + self.interpolate_pos_encoding(tokens, height, width)
sf3d/models/tokenizers/image.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+
10
+ from sf3d.models.tokenizers.dinov2 import Dinov2Model
11
+ from sf3d.models.transformers.attention import Modulation
12
+ from sf3d.models.utils import BaseModule
13
+
14
+
15
class DINOV2SingleImageTokenizer(BaseModule):
    """Tokenize images with a frozen, modulation-conditioned `Dinov2Model`."""

    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: str = "facebook/dinov2-large"
        width: int = 512
        height: int = 512
        modulation_cond_dim: int = 768

    cfg: Config

    def configure(self) -> None:
        """Load the pretrained model, freeze it and attach per-layer modulations."""
        self.model = Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path)

        # The pretrained weights stay frozen; only the modulation layers
        # created below keep requires_grad=True.
        self.model.requires_grad_(False)
        self.model.eval()

        self.model.set_gradient_checkpointing(False)

        # add modulation: two Modulation modules per encoder layer, handed
        # to the layer via register_ada_norm_modulation.
        hidden_size = self.model.config.hidden_size
        cond_dim = self.cfg.modulation_cond_dim
        modulations = []
        for layer in self.model.encoder.layer:
            norm_modulations = [
                Modulation(hidden_size, cond_dim, zero_init=True, single_layer=True)
                for _ in range(2)
            ]
            layer.register_ada_norm_modulation(*norm_modulations)
            modulations.extend(norm_modulations)
        self.modulations = nn.ModuleList(modulations)

        # Per-channel normalisation constants, shaped to broadcast over
        # (B, N, C, H, W) inputs; non-persistent so they are not saved in
        # the state dict.
        for name, values in (
            ("image_mean", [0.485, 0.456, 0.406]),
            ("image_std", [0.229, 0.224, 0.225]),
        ):
            self.register_buffer(
                name,
                torch.as_tensor(values).reshape(1, 1, 3, 1, 1),
                persistent=False,
            )

    def forward(
        self,
        images: Float[Tensor, "B *N C H W"],
        modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
        **kwargs,
    ) -> Float[Tensor, "B *N Ct Nt"]:
        """Return the model's last hidden state with the channel axis before the token axis.

        4D `images` (no view axis) get a singleton view axis added for the
        computation, which is removed from the result again.
        """
        packed = images.ndim == 4
        if packed:
            images = images.unsqueeze(1)
            if modulation_cond is not None:
                assert modulation_cond.ndim == 2
                modulation_cond = modulation_cond.unsqueeze(1)

        batch_size, _ = images.shape[:2]
        normalized = (images - self.image_mean) / self.image_std

        # Views are folded into the batch axis for the model call.
        flat_cond = None
        if modulation_cond is not None:
            flat_cond = rearrange(modulation_cond, "B N Cc -> (B N) Cc")
        out = self.model(
            rearrange(normalized, "B N C H W -> (B N) C H W"),
            modulation_cond=flat_cond,
        )

        # (B*N, Nt, Ct) -> (B*N, Ct, Nt) -> (B, N, Ct, Nt)
        local_features = rearrange(
            out.last_hidden_state.permute(0, 2, 1),
            "(B N) Ct Nt -> B N Ct Nt",
            B=batch_size,
        )
        return local_features.squeeze(1) if packed else local_features

    def detokenize(self, *args, **kwargs):
        """Not supported for this tokenizer."""
        raise NotImplementedError
sf3d/models/tokenizers/triplane.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange, repeat
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+
10
+ from sf3d.models.utils import BaseModule
11
+
12
+
13
class TriplaneLearnablePositionalEmbedding(BaseModule):
    """Learnable embeddings for three square planes, exposed as a flat token sequence."""

    @dataclass
    class Config(BaseModule.Config):
        plane_size: int = 96
        num_channels: int = 1024

    cfg: Config

    def configure(self) -> None:
        """Create the (3, num_channels, plane_size, plane_size) parameter."""
        channels = self.cfg.num_channels
        size = self.cfg.plane_size
        # Standard-normal init scaled by 1 / sqrt(num_channels).
        init = torch.randn((3, channels, size, size), dtype=torch.float32)
        self.embeddings = nn.Parameter(init * 1 / math.sqrt(channels))

    def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
        """Return the embeddings repeated `batch_size` times, planes flattened into tokens."""
        batched = repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size)
        return rearrange(batched, "B Np Ct Hp Wp -> B Ct (Np Hp Wp)")

    def detokenize(
        self, tokens: Float[Tensor, "B Ct Nt"]
    ) -> Float[Tensor, "B 3 Ct Hp Wp"]:
        """Unflatten a token sequence back into three planes (inverse layout of `forward`)."""
        _, Ct, Nt = tokens.shape
        size = self.cfg.plane_size
        assert Nt == size**2 * 3
        assert Ct == self.cfg.num_channels
        return rearrange(
            tokens,
            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
            Np=3,
            Hp=size,
            Wp=size,
        )
sf3d/models/transformers/attention.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class Modulation(nn.Module):
    """Scale-and-shift modulation of `x` driven by a condition vector.

    The condition is mapped to `2 * embedding_dim` values, split into a
    scale and a shift, and applied as `x * (1 + scale) + shift`.
    """

    def __init__(
        self,
        embedding_dim: int,
        condition_dim: int,
        zero_init: bool = False,
        single_layer: bool = False,
    ):
        super().__init__()
        self.silu = nn.SiLU()
        # With `single_layer` the first projection is skipped (identity).
        self.linear1 = (
            nn.Identity() if single_layer else nn.Linear(condition_dim, condition_dim)
        )
        self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)

        # Only zero init the last linear layer, which makes the module start
        # out as an identity mapping on x.
        if zero_init:
            for param in (self.linear2.weight, self.linear2.bias):
                nn.init.zeros_(param)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """Apply the condition-dependent scale and shift to `x`.

        The scale/shift are split along dim 1 of the projected condition and
        broadcast over a new axis 1 of `x`.
        """
        hidden = self.silu(self.linear1(condition))
        scale, shift = self.linear2(hidden).chunk(2, dim=1)
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
sf3d/models/transformers/backbone.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from sf3d.models.utils import BaseModule
9
+
10
+
11
class GEGLU(nn.Module):
    r"""
    Gated linear unit with a GELU gate, see https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # A single projection produces both the value and the gate halves.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        """Apply GELU, computing in float32 on MPS devices."""
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16
            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
        return F.gelu(gate)

    def forward(self, hidden_states, scale: float = 1.0):
        """Return `value * gelu(gate)`; `scale` is accepted but not used."""
        value, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return value * self.gelu(gate)
34
+
35
+
36
+ class CrossAttention(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim,
40
+ kv_dim=None,
41
+ num_heads=16,
42
+ qkv_bias=False,
43
+ attn_drop=0.0,
44
+ proj_drop=0.0,
45
+ ):
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ self.scale = head_dim**-0.5
50
+ kv_dim = dim if not kv_dim else kv_dim
51
+ self.wq = nn.Linear(dim, dim, bias=qkv_bias)
52
+ self.wk = nn.Linear(kv_dim, dim, bias=qkv_bias)
53
+ self.wv = nn.Linear(kv_dim, dim, bias=qkv_bias)
54
+ self.attn_drop = attn_drop
55
+ self.proj = nn.Linear(dim, dim)
56
+ self.proj_drop = nn.Dropout(proj_drop)
57
+
58
+ def forward(self, x_q, x_kv):
59
+ B, N_q, C = x_q.shape
60
+ B, N_kv, _ = x_kv.shape
61
+ # [B, N_q, C] -> [B, N_q, H, C/H]
62
+ q = self.wq(x_q).reshape(B, N_q, self.num_heads, C // self.num_heads)
63
+ # [B, N_kv, C] -> [B, N_kv, H, C/H]
64
+ k = self.wk(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
65
+ v = self.wv(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
66
+
67
+ # attention
68
+ x = torch.nn.functional.scaled_dot_product_attention(
69
+ q.permute(0, 2, 1, 3),
70
+ k.permute(0, 2, 1, 3),
71
+ v.permute(0, 2, 1, 3),
72
+ attn_mask=None,
73
+ dropout_p=self.attn_drop,
74
+ scale=self.scale,
75
+ ).permute(0, 2, 1, 3)
76
+
77
+ # [B, N_q, H, C/H] -> [B, N_q, C]
78
+ x = x.reshape(B, N_q, C)
79
+ x = self.proj(x)
80
+ x = self.proj_drop(x)
81
+ return x
82
+
83
+
84
class FeedForward(nn.Module):
    """GEGLU projection, dropout, then a linear layer back to `dim_out`.

    Args:
        dim: Input channel width.
        dim_out: Output channel width; defaults to `dim`.
        mult: Expansion factor of the hidden width (`int(dim * mult)`).
        dropout: Dropout probability between the two projections.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
    ):
        super().__init__()
        hidden_dim = int(dim * mult)
        if dim_out is None:
            dim_out = dim
        self.net = nn.ModuleList(
            [
                GEGLU(dim, hidden_dim),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, dim_out),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run `x` through the sub-modules of `self.net` in order."""
        out = x
        for layer in self.net:
            out = layer(out)
        return out
105
+
106
+
107
class BasicBlock(nn.Module):
    """Pre-norm block: self-attention, cross-attention, feed-forward, each with a residual.

    Args:
        dim: Channel width of the stream `z`.
        kv_dim: Channel width of the cross-attention context `x`.
        num_heads: Number of heads for both attention layers.
        qkv_bias: Whether the q/k/v projections have a bias.
        attn_drop: Attention dropout probability.
        proj_drop: Dropout probability after the attention output projections.
        ff_drop: Dropout probability inside the feed-forward layer.
    """

    def __init__(
        self,
        dim: int,
        kv_dim: Optional[int] = None,
        num_heads: int = 16,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        ff_drop: float = 0.0,
    ):
        super().__init__()
        # Settings shared by the two attention layers.
        shared = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.attn1 = CrossAttention(dim, kv_dim=dim, **shared)
        self.norm2 = nn.LayerNorm(dim)
        self.attn2 = CrossAttention(dim, kv_dim=kv_dim, **shared)
        self.norm3 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, dropout=ff_drop)

    def forward(self, z, x):
        """Update `z`, attending to itself and then to `x` (or to itself when `x` is None)."""
        normed = self.norm1(z)
        z = z + self.attn1(normed, normed)

        # TODO: do we need to have the second attention when x is None?
        normed = self.norm2(z)
        context = normed if x is None else x
        z = z + self.attn2(normed, context)

        return z + self.ff(self.norm3(z))
149
+
150
+
151
class SingleStreamTransformer(BaseModule):
    """Stack of `BasicBlock`s applied to channel-first tokens, with a residual around the stack."""

    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 88
        in_channels: Optional[int] = None
        out_channels: Optional[int] = None
        num_layers: int = 16
        dropout: float = 0.0
        norm_num_groups: int = 32
        cross_attention_dim: Optional[int] = None
        attention_bias: bool = False

    cfg: Config

    def configure(self) -> None:
        """Build the input norm/projection, the transformer blocks and the output projection."""
        cfg = self.cfg
        self.num_attention_heads = cfg.num_attention_heads
        self.attention_head_dim = cfg.attention_head_dim
        inner_dim = self.num_attention_heads * self.attention_head_dim

        # Define input layers
        self.norm = torch.nn.GroupNorm(
            num_groups=cfg.norm_num_groups,
            num_channels=cfg.in_channels,
            eps=1e-6,
            affine=True,
        )
        self.proj_in = nn.Linear(cfg.in_channels, inner_dim)

        # Define transformers blocks
        blocks = []
        for _ in range(cfg.num_layers):
            blocks.append(
                BasicBlock(
                    inner_dim,
                    kv_dim=cfg.cross_attention_dim,
                    num_heads=self.num_attention_heads,
                    qkv_bias=cfg.attention_bias,
                    proj_drop=cfg.dropout,
                    ff_drop=cfg.dropout,
                )
            )
        self.transformer_blocks = nn.ModuleList(blocks)

        # 4. Define output layers (project back to the input width so the
        # residual in forward() can be added)
        self.proj_out = nn.Linear(inner_dim, cfg.in_channels)

    def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
        """Transform channel-first `hidden_states`, optionally cross-attending to `encoder_hidden_states`."""
        skip = hidden_states

        # Normalise channel-first, then switch to token-major for the blocks.
        tokens = self.proj_in(self.norm(hidden_states).permute(0, 2, 1))
        for block in self.transformer_blocks:
            tokens = block(tokens, encoder_hidden_states)

        # Back to the input width and channel-first layout.
        out = self.proj_out(tokens).permute(0, 2, 1).contiguous()
        # TODO: do we really need to add the residual?
        return out + skip
209
+
210
+
211
class FuseBlock(nn.Module):
    """
    Fuse X into Z: Z cross-attends to X, followed by a feed-forward layer, each with a residual.
    """

    def __init__(
        self,
        dim_z: int,
        dim_x: int,
        num_heads: int = 16,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        ff_drop: float = 0.0,
        norm_x_input: bool = True,
    ):
        super().__init__()
        self.norm_x_input = norm_x_input
        # The context norm only exists when it is actually used.
        if norm_x_input:
            self.norm_x = nn.LayerNorm(dim_x)
        self.attn = CrossAttention(
            dim_z,
            kv_dim=dim_x,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.norm_z1 = nn.LayerNorm(dim_z)
        self.norm_z2 = nn.LayerNorm(dim_z)
        self.ff = FeedForward(dim_z, dropout=ff_drop)

    def forward(self, z, x):
        """Return `z` updated with information from `x`."""
        # TODO: do we need to normalize x?
        query = self.norm_z1(z)
        context = self.norm_x(x) if self.norm_x_input else x
        z = z + self.attn(query, context)
        return z + self.ff(self.norm_z2(z))
248
+
249
+
250
@torch.no_grad()
def get_triplane_attention_mask(res):
    """Build the additive attention bias between the tokens of three `res` x `res` planes.

    Tokens are indexed as (plane, i, j). A token may attend to:

    * plane 0, (i, j): plane 1 at first index i, plane 2 at first index j
    * plane 1, (i, j): plane 0 at first index i, plane 2 at second index j
    * plane 2, (i, j): plane 0 at second index i, plane 1 at second index j

    Tokens never attend within their own plane.

    Args:
        res: Side length of each plane.

    Returns:
        Float tensor of shape `(3 * res**2, 3 * res**2)` holding 0.0 where
        attention is allowed and -inf elsewhere.
    """
    N = 3 * res * res
    allowed = torch.zeros(3, res, res, 3, res, res, dtype=torch.bool)

    # Explicit "ij" indexing matches the legacy default and avoids the
    # deprecation warning torch.meshgrid emits when `indexing` is omitted.
    i, j = torch.meshgrid(torch.arange(res), torch.arange(res), indexing="ij")

    allowed[0, i, j, 1, i, :] = True
    allowed[0, i, j, 2, j, :] = True
    allowed[1, i, j, 0, i, :] = True
    allowed[1, i, j, 2, :, j] = True
    allowed[2, i, j, 0, :, i] = True
    allowed[2, i, j, 1, :, j] = True

    attn_bias = torch.zeros(allowed.shape, dtype=torch.float)
    attn_bias.masked_fill_(~allowed, float("-inf"))

    return attn_bias.reshape(N, N)
270
+
271
+
272
+ class TriplaneAttention(nn.Module):
273
+ def __init__(
274
+ self,
275
+ dim: int,
276
+ resolution: int,
277
+ num_heads: int = 16,
278
+ qkv_bias: bool = False,
279
+ attn_drop: float = 0.0,
280
+ proj_drop: float = 0.0,
281
+ full_attention: bool = False,
282
+ ):
283
+ super().__init__()
284
+ self.num_heads = num_heads
285
+ head_dim = dim // num_heads
286
+ self.scale = head_dim**-0.5
287
+ self.wq = nn.Linear(dim, dim, bias=qkv_bias)
288
+ self.wk = nn.Linear(dim, dim, bias=qkv_bias)
289
+ self.wv = nn.Linear(dim, dim, bias=qkv_bias)
290
+ self.attn_drop = attn_drop
291
+ self.proj = nn.Linear(dim, dim)
292
+ self.proj_drop = nn.Dropout(proj_drop)
293
+
294
+ self.resolution = resolution
295
+ self.full_attention = full_attention
296
+ self.attn_mask = (
297
+ get_triplane_attention_mask(resolution) if not full_attention else None
298
+ )
299
+
300
+ def forward(self, x):
301
+ B, N, C = x.shape
302
+ # [B, N, C] -> [B, N, H, C/H]
303
+ q = self.wq(x).reshape(B, N, self.num_heads, C // self.num_heads)
304
+ k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads)
305
+ v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads)
306
+
307
+ # detokenize the planes
308
+ assert N == self.resolution**2 * 3
309
+ attn_bias = (
310
+ self.attn_mask.to(q)
311
+ .unsqueeze(0)
312
+ .unsqueeze(0)
313
+ .expand(B, self.num_heads, -1, -1)
314
+ if not self.full_attention
315
+ else None
316
+ )
317
+
318
+ # full attention
319
+ x = torch.nn.functional.scaled_dot_product_attention(
320
+ q.permute(0, 2, 1, 3),
321
+ k.permute(0, 2, 1, 3),
322
+ v.permute(0, 2, 1, 3),
323
+ attn_mask=attn_bias,
324
+ dropout_p=self.attn_drop,
325
+ scale=self.scale,
326
+ ).permute(0, 2, 1, 3)
327
+
328
+ # [B, N_q, H, C/H] -> [B, N_q, C]
329
+ x = x.reshape(B, N, C)
330
+ x = self.proj(x)
331
+ x = self.proj_drop(x)
332
+ return x
333
+
334
+
335
class TwoStreamBlock(nn.Module):
    """Exchange information between a latent stream and an input stream.

    The input is fused into the latent, the latent is processed by
    `num_basic_blocks` `BasicBlock`s (cross-attending to `cross_input`), and
    the latent is finally fused back into the input.
    """

    def __init__(
        self,
        dim_latent: int,
        dim_input: int,
        num_basic_blocks: int = 4,
        num_heads: int = 16,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        ff_drop: float = 0.0,
        norm_x_input: bool = True,
        dim_cross: Optional[int] = None,
    ):
        super().__init__()

        # Settings shared by the two fuse blocks.
        fuse_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            ff_drop=ff_drop,
            norm_x_input=norm_x_input,
        )

        # Define the fuse block that fuse the input into the latent
        self.fuse_block_in = FuseBlock(dim_latent, dim_input, **fuse_kwargs)

        # Define the transformer block that process the latent
        # NOTE(review): attn_drop is not forwarded to these BasicBlocks, so
        # they use their default - confirm this is intended.
        latent_blocks = []
        for _ in range(num_basic_blocks):
            latent_blocks.append(
                BasicBlock(
                    dim_latent,
                    kv_dim=dim_cross,
                    num_heads=num_heads,
                    qkv_bias=qkv_bias,
                    proj_drop=proj_drop,
                    ff_drop=ff_drop,
                )
            )
        self.transformer_block = nn.ModuleList(latent_blocks)

        # Define the fuse block that fuse the latent into the input
        self.fuse_block_out = FuseBlock(dim_input, dim_latent, **fuse_kwargs)

    def forward(self, latent, input, cross_input):
        """Return the updated `(latent, input)` pair."""
        latent = self.fuse_block_in(latent, input)
        for layer in self.transformer_block:
            latent = layer(latent, cross_input)
        return latent, self.fuse_block_out(input, latent)
396
+
397
+
398
class TwoStreamInterleaveTransformer(BaseModule):
    """Transformer that interleaves a latent stream with a triplane token stream.

    Triplane tokens and a set of learned latents (optionally prefixed with
    projected image tokens) are passed through a stack of `TwoStreamBlock`s;
    the triplane tokens are then projected back and added to the input.
    """

    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 64
        raw_triplane_channels: int = 1024
        triplane_channels: int = 1024
        raw_image_channels: int = 1024
        num_latents: int = 1792
        num_blocks: int = 4
        num_basic_blocks: int = 3
        dropout: float = 0.0
        latent_init_std: float = 0.02
        norm_num_groups: int = 32
        attention_bias: bool = False
        norm_x_input: bool = False
        cross_attention_dim: int = 1024
        mix_latent: bool = True

    cfg: Config

    def configure(self) -> None:
        """Create the input layers, the learned latents, the main blocks and the output projection."""
        cfg = self.cfg
        self.mix_latent = cfg.mix_latent

        # Define the dimensions
        self.num_attention_heads = cfg.num_attention_heads
        self.attention_head_dim = cfg.attention_head_dim
        self.num_latents = cfg.num_latents
        self.latent_dim = self.num_attention_heads * self.attention_head_dim

        # Define input layers. A positive group count selects GroupNorm
        # (applied channel-first in forward); otherwise LayerNorm is used.
        if cfg.norm_num_groups > 0:
            self.norm_triplane = torch.nn.GroupNorm(
                num_groups=cfg.norm_num_groups,
                num_channels=cfg.raw_triplane_channels,
                eps=1e-6,
                affine=True,
            )
        else:
            self.norm_triplane = nn.LayerNorm(cfg.raw_triplane_channels)
        self.proj_triplane = nn.Linear(
            cfg.raw_triplane_channels, cfg.triplane_channels
        )
        if self.mix_latent:
            self.norm_image = nn.LayerNorm(cfg.raw_image_channels)
            self.proj_image = nn.Linear(cfg.raw_image_channels, self.latent_dim)
        self.norm_latent = nn.LayerNorm(self.latent_dim)
        self.proj_latent = nn.Linear(self.latent_dim, self.latent_dim)

        # Define the latents
        self.latent_init = nn.Parameter(
            torch.zeros(1, self.num_latents, self.latent_dim)
        )
        nn.init.normal_(self.latent_init, std=cfg.latent_init_std)

        # Define the transformer blocks
        blocks = []
        for _ in range(cfg.num_blocks):
            blocks.append(
                TwoStreamBlock(
                    self.latent_dim,
                    cfg.triplane_channels,
                    num_basic_blocks=cfg.num_basic_blocks,
                    num_heads=self.num_attention_heads,
                    qkv_bias=cfg.attention_bias,
                    proj_drop=cfg.dropout,
                    ff_drop=cfg.dropout,
                    norm_x_input=cfg.norm_x_input,
                    dim_cross=cfg.cross_attention_dim,
                )
            )
        self.main_blocks = nn.ModuleList(blocks)

        # 4. Define output layers
        self.proj_out = nn.Linear(cfg.triplane_channels, cfg.raw_triplane_channels)

    def forward(self, hidden_states, encoder_hidden_states, **kwargs):
        """Refine the triplane tokens using the image tokens.

        Raises:
            ValueError: if `norm_triplane` is neither GroupNorm nor LayerNorm.
        """
        # hidden_states: [B, triplane_dim, N_triplane] is triplane tokens
        # encoder_hidden_states: [B, N_image, image_dim] is the image tokens
        norm = self.norm_triplane
        if isinstance(norm, nn.GroupNorm):
            # GroupNorm runs channel-first, then tokens move to axis 1:
            # [B, N_triplane, triplane_dim]
            triplane_tokens = norm(hidden_states).permute(0, 2, 1)
        elif isinstance(norm, nn.LayerNorm):
            triplane_tokens = norm(hidden_states.permute(0, 2, 1))
        else:
            raise ValueError("Unknown normalization layer")
        triplane_tokens = self.proj_triplane(triplane_tokens)

        # [B, N_latent_init, latent_dim]
        batch = hidden_states.shape[0]
        latent_tokens = self.proj_latent(
            self.norm_latent(self.latent_init.expand(batch, -1, -1))
        )
        if self.mix_latent:
            # Projected image tokens are placed in front of the learned
            # latents: [B, N_latent, latent_dim]
            image_tokens = self.proj_image(self.norm_image(encoder_hidden_states))
            latent_tokens = torch.cat([image_tokens, latent_tokens], dim=1)

        # forward the main blocks
        for block in self.main_blocks:
            latent_tokens, triplane_tokens = block(
                latent_tokens, triplane_tokens, encoder_hidden_states
            )

        # project the triplane tokens back to the original dimension and
        # layout, then add the input as a residual
        refined = self.proj_out(triplane_tokens).permute(0, 2, 1).contiguous()
        return refined + hidden_states
sf3d/models/utils.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import importlib
3
+ from dataclasses import dataclass
4
+ from typing import Any, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import PIL
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from jaxtyping import Float, Int, Num
12
+ from omegaconf import DictConfig, OmegaConf
13
+ from torch import Tensor
14
+
15
+
16
class BaseModule(nn.Module):
    """``nn.Module`` base whose construction is driven by a structured config.

    ``__init__`` validates ``cfg`` against the nested ``Config`` dataclass via
    ``parse_structured`` and then delegates all real setup to ``configure``,
    which subclasses must override.
    """

    @dataclass
    class Config:
        """Empty configuration schema; subclasses declare their own fields."""

    cfg: Config  # add this to every subclass of BaseModule to enable static type checking

    def __init__(
        self,
        cfg: Optional[Union[dict, DictConfig]] = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        # Keys not declared on ``Config`` are dropped by ``parse_structured``.
        self.cfg = parse_structured(self.Config, cfg)
        # Extra positional/keyword arguments are forwarded untouched.
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        """Build the module's layers; the base implementation always raises."""
        raise NotImplementedError
32
+
33
+
34
def find_class(cls_string):
    """Import and return the attribute named by a dotted path.

    Everything before the last dot is imported as a module; the final
    component is looked up on that module with ``getattr``.
    """
    module_path, _, attr_name = cls_string.rpartition(".")
    owner = importlib.import_module(module_path, package=None)
    return getattr(owner, attr_name)
40
+
41
+
42
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    """Merge ``cfg`` into the structured schema described by dataclass ``fields``.

    Args:
        fields: Dataclass type that defines the allowed keys and defaults.
        cfg: User configuration. ``None`` (the default) is treated as an empty
            configuration, so only the dataclass defaults are used.

    Returns:
        The OmegaConf config produced by merging the structured schema with
        the supported keys of ``cfg``.
    """
    # Work on a copy so the caller's config is never mutated. The default
    # ``cfg=None`` used to crash on ``None.copy()``.
    cfg_ = {} if cfg is None else cfg.copy()

    field_names = {f.name for f in dataclasses.fields(fields)}
    # Snapshot the keys first: entries are popped while walking them.
    for key in list(cfg_.keys()):
        # This is helpful when swapping out modules from CLI
        if key not in field_names:
            print(f"Ignoring {key} as it's not supported by {fields}")
            cfg_.pop(key)
    return OmegaConf.merge(OmegaConf.structured(fields), cfg_)
55
+
56
+
57
# Per-dtype epsilon used by ``normalize`` when the caller does not supply one.
EPS_DTYPE = {
    torch.float16: 1e-4,
    torch.bfloat16: 1e-4,
    torch.float32: 1e-7,
    torch.float64: 1e-8,
}


def dot(x, y, dim=-1):
    """Return the dot product of ``x`` and ``y`` along ``dim``, keeping that dim."""
    return (x * y).sum(dim, keepdim=True)


def reflect(x, n):
    """Return ``x - 2 * dot(x, n) * n`` (reflection of ``x`` about ``n``)."""
    projection = dot(x, n)
    return x - 2 * projection * n


def normalize(x, dim=-1, eps=None):
    """L2-normalise ``x`` along ``dim``.

    When ``eps`` is ``None`` it is looked up in ``EPS_DTYPE`` by ``x.dtype``;
    a dtype missing from that table raises ``KeyError``.
    """
    epsilon = EPS_DTYPE[x.dtype] if eps is None else eps
    return F.normalize(x, p=2, dim=dim, eps=epsilon)
77
+
78
+
79
+ ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
80
+
81
+
82
def scale_tensor(
    dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
):
    """Linearly remap ``dat`` from the ``inp_scale`` range to ``tgt_scale``.

    Each scale is a ``(low, high)`` pair (tuple or tensor indexed by 0/1).
    ``None`` for either scale means ``(0, 1)``. When ``tgt_scale`` is a tensor
    its last dimension must match that of ``dat``.
    """
    inp_scale = (0, 1) if inp_scale is None else inp_scale
    tgt_scale = (0, 1) if tgt_scale is None else tgt_scale
    if isinstance(tgt_scale, Tensor):
        assert dat.shape[-1] == tgt_scale.shape[-1]

    src_low, src_high = inp_scale[0], inp_scale[1]
    dst_low, dst_high = tgt_scale[0], tgt_scale[1]

    # Normalise to the unit interval first, then stretch to the target range.
    unit = (dat - src_low) / (src_high - src_low)
    return unit * (dst_high - dst_low) + dst_low
94
+
95
+
96
def dilate_fill(img, mask, iterations=10):
    """Grow the valid region of ``img`` outwards, filling with neighbour means.

    Each iteration dilates the mask by one pixel (3x3 max-pool) and writes an
    averaged colour into the newly covered pixels; pixels that were already
    valid are left unchanged.

    Args:
        img: Image tensor. The hard-coded ``view(1, 3, 9, -1)`` below requires
            a single image with 3 channels in channel-first layout.
        mask: Validity mask with a single channel; converted with ``.float()``.
        iterations: Number of one-pixel dilation steps.

    Returns:
        The filled image (``img`` itself when ``iterations`` is 0).
    """
    valid = mask.float()
    filled = img
    out_size = (img.shape[-2], img.shape[-1])

    ones_kernel = torch.ones((1, 1, 3, 3), dtype=valid.dtype, device=valid.device)

    for _ in range(iterations):
        grown = F.max_pool2d(valid, 3, 1, 1)

        # 3x3 patches of the image and of both masks.
        patches = F.unfold(filled, (3, 3)).view(1, 3, 3 * 3, -1)
        valid_patches = F.unfold(valid, (3, 3)).view(1, 1, 3 * 3, -1)
        grown_patches = F.unfold(grown, (3, 3)).view(1, 1, 3 * 3, -1)

        # Patch sum divided by the number of valid pixels in the patch
        # (the count is clipped to at least 1 to avoid dividing by zero).
        patch_mean = (patches.sum(dim=2) / valid_patches.sum(dim=2).clip(1)).unsqueeze(2)
        # Spread each patch mean over the dilated-mask pixels of that patch.
        spread = (patch_mean * grown_patches).view(1, 3 * 3 * 3, -1)

        # Per-pixel count of dilated-mask pixels in the 3x3 neighbourhood.
        coverage = F.conv2d(grown, ones_kernel, padding=1)
        candidate = F.fold(spread, out_size, (3, 3)) / coverage.clamp(1)

        # Only the freshly added ring (grown - valid == 1) takes the new value.
        filled = torch.lerp(filled, candidate, grown - valid)
        valid = grown

    return filled
134
+
135
+
136
def float32_to_uint8_np(
    x: Float[np.ndarray, "*B H W C"],
    dither: bool = True,
    dither_mask: Optional[Float[np.ndarray, "*B H W C"]] = None,
    dither_strength: float = 1.0,
) -> Int[np.ndarray, "*B H W C"]:
    """Quantise a float image to ``uint8`` with optional random dithering.

    Values are scaled by 256, floored and clipped to ``[0, 255]``.

    Args:
        x: Float image array.
        dither: Add random noise before flooring when True.
        dither_mask: Optional array multiplied into the noise; zeros disable
            dithering at those positions.
        dither_strength: Scale applied to the uniform random sample.

    Returns:
        ``np.uint8`` array with the same shape as ``x``.
    """
    scaled = 256.0 * x
    if dither:
        # One noise sample per pixel, broadcast over the channel axis.
        # NOTE(review): the -0.5 offset is applied after scaling, so the noise
        # is only zero-centred for dither_strength == 1.0 — confirm intent.
        noise = (
            dither_strength * np.random.rand(*x[..., :1].shape).astype(np.float32) - 0.5
        )
        if dither_mask is not None:
            noise = noise * dither_mask
        scaled = scaled + noise
    # The non-dither path used ``astype(torch.uint8)``, which NumPy rejects
    # with a TypeError; both paths now produce a NumPy uint8 array.
    return np.clip(np.floor(scaled), 0, 255).astype(np.uint8)
150
+
151
+
152
def convert_data(data):
    """Recursively convert tensors inside ``data`` to NumPy arrays.

    Args:
        data: ``None``, a NumPy array, a torch tensor, or a list/dict nesting
            any of these.

    Returns:
        ``None`` and NumPy arrays unchanged; tensors detached, moved to CPU
        and converted (half/bfloat16 are upcast to float32 first); lists and
        dicts rebuilt with converted elements.

    Raises:
        TypeError: For any other input type.
    """
    if data is None:
        return None
    if isinstance(data, np.ndarray):
        return data
    if isinstance(data, torch.Tensor):
        # Upcast reduced-precision floats before handing the data to NumPy.
        if data.dtype in (torch.float16, torch.bfloat16):
            data = data.float()
        return data.detach().cpu().numpy()
    if isinstance(data, list):
        return [convert_data(item) for item in data]
    if isinstance(data, dict):
        return {key: convert_data(value) for key, value in data.items()}
    # Single formatted message; the two-argument form rendered as a tuple.
    raise TypeError(
        "Data must be in type numpy.ndarray, torch.Tensor, list or dict, "
        f"getting {type(data)}"
    )
170
+
171
+
172
class ImageProcessor:
    """Convert images to float tensors and resize them to ``size`` x ``size``."""

    def convert_and_resize(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        size: int,
    ):
        """Convert one image (or a 4-D batch) to a tensor and resize it.

        PIL images and ``uint8`` arrays are scaled to ``[0, 1]`` floats; other
        arrays and tensors are used as-is. The permutes below treat the input
        as channel-last.

        Raises:
            TypeError: If ``image`` is not a PIL image, NumPy array or tensor.
        """
        if isinstance(image, PIL.Image.Image):
            image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
        elif isinstance(image, np.ndarray):
            if image.dtype == np.uint8:
                image = torch.from_numpy(image.astype(np.float32) / 255.0)
            else:
                image = torch.from_numpy(image)
        elif not isinstance(image, torch.Tensor):
            # Previously fell through and failed later with an opaque
            # AttributeError on ``.ndim``.
            raise TypeError(
                f"Unsupported image type {type(image)}; expected PIL image, "
                "numpy.ndarray or torch.Tensor"
            )

        batched = image.ndim == 4
        if not batched:
            image = image[None, ...]
        # interpolate expects channel-first, so permute in and back out.
        image = F.interpolate(
            image.permute(0, 3, 1, 2),
            (size, size),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).permute(0, 2, 3, 1)
        if not batched:
            image = image[0]
        return image

    def __call__(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        size: int,
    ) -> Any:
        """Resize a single image, a list of images, or an already-batched array.

        Returns a tensor with a leading batch dimension.
        """
        # ``torch.FloatTensor`` only matches CPU float32 tensors, so a batched
        # GPU tensor used to be wrapped in a list and stacked into a 5-D
        # result. Check against ``torch.Tensor`` instead.
        if isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim == 4:
            return self.convert_and_resize(image, size)

        images = image if isinstance(image, list) else [image]
        resized = [self.convert_and_resize(im, size) for im in images]
        return torch.stack(resized, dim=0)
223
+
224
+
225
def get_intrinsic_from_fov(fov, H, W, bs=-1):
    """Build a 3x3 pinhole intrinsic matrix from a field of view.

    Args:
        fov: Field of view passed to ``np.tan`` (radians); combined with ``H``
            to derive a single focal length used on both axes.
        H: Image height in pixels.
        W: Image width in pixels.
        bs: When positive, the matrix is repeated into a ``(bs, 3, 3)`` batch.

    Returns:
        ``float32`` torch tensor, ``(3, 3)`` or ``(bs, 3, 3)``.
    """
    focal = 0.5 * H / np.tan(0.5 * fov)
    matrix = np.array(
        [
            [focal, 0.0, W / 2.0],
            [0.0, focal, H / 2.0],
            [0.0, 0.0, 1.0],
        ],
        dtype=np.float32,
    )
    if bs > 0:
        matrix = np.repeat(matrix[None], bs, axis=0)
    return torch.from_numpy(matrix)
sf3d/system.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from contextlib import nullcontext
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, List, Literal, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import trimesh
10
+ from einops import rearrange
11
+ from huggingface_hub import hf_hub_download
12
+ from jaxtyping import Float
13
+ from omegaconf import OmegaConf
14
+ from PIL import Image
15
+ from safetensors.torch import load_model
16
+ from torch import Tensor
17
+
18
+ from sf3d.models.isosurface import MarchingTetrahedraHelper
19
+ from sf3d.models.mesh import Mesh
20
+ from sf3d.models.utils import (
21
+ BaseModule,
22
+ ImageProcessor,
23
+ convert_data,
24
+ dilate_fill,
25
+ find_class,
26
+ float32_to_uint8_np,
27
+ normalize,
28
+ scale_tensor,
29
+ )
30
+ from sf3d.utils import create_intrinsic_from_fov_deg, default_cond_c2w, get_device
31
+
32
+ try:
33
+ from texture_baker import TextureBaker
34
+ except ImportError:
35
+ import logging
36
+
37
+ logging.warning(
38
+ "Could not import texture_baker. Please install it via `pip install texture-baker/`"
39
+ )
40
+ # Exit early to avoid further errors
41
+ raise ImportError("texture_baker not found")
42
+
43
+
44
+ class SF3D(BaseModule):
45
@dataclass
class Config(BaseModule.Config):
    # Side length (pixels) the conditioning image is resized to; also used as
    # height and width when building the default intrinsics.
    cond_image_size: int
    # Resolution handed to MarchingTetrahedraHelper; also selects the
    # "<resolution>_tets.npz" file that is loaded.
    isosurface_resolution: int
    # Subtracted from the decoded density before isosurface extraction.
    isosurface_threshold: float = 10.0
    # Half-extent of the axis-aligned bounding box [-radius, radius]^3.
    radius: float = 1.0
    # Colour blended behind the input image according to its alpha channel.
    background_color: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5])
    # Vertical field of view (degrees) of the default conditioning camera.
    default_fovy_deg: float = 40.0
    # Distance passed to default_cond_c2w for the default conditioning camera.
    default_distance: float = 1.6

    # Each "<name>_cls" is a dotted class path resolved with find_class; the
    # matching "<name>" dict is passed to that class as its config.
    camera_embedder_cls: str = ""
    camera_embedder: dict = field(default_factory=dict)

    image_tokenizer_cls: str = ""
    image_tokenizer: dict = field(default_factory=dict)

    tokenizer_cls: str = ""
    tokenizer: dict = field(default_factory=dict)

    backbone_cls: str = ""
    backbone: dict = field(default_factory=dict)

    post_processor_cls: str = ""
    post_processor: dict = field(default_factory=dict)

    decoder_cls: str = ""
    decoder: dict = field(default_factory=dict)

    image_estimator_cls: str = ""
    image_estimator: dict = field(default_factory=dict)

    global_estimator_cls: str = ""
    global_estimator: dict = field(default_factory=dict)
78
+
79
+ cfg: Config
80
+
81
@classmethod
def from_pretrained(
    cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
):
    """Instantiate the model from a local directory or a Hugging Face repo.

    Args:
        pretrained_model_name_or_path: Existing directory containing the
            files, otherwise treated as a Hub ``repo_id``.
        config_name: File name of the OmegaConf config.
        weight_name: File name of the safetensors weights.

    Returns:
        Model built from the resolved config with the weights loaded.
    """
    source = pretrained_model_name_or_path
    if os.path.isdir(source):
        config_path, weight_path = (
            os.path.join(source, file_name)
            for file_name in (config_name, weight_name)
        )
    else:
        # Not a local directory: fetch both files from the Hub.
        config_path, weight_path = (
            hf_hub_download(repo_id=source, filename=file_name)
            for file_name in (config_name, weight_name)
        )

    cfg = OmegaConf.load(config_path)
    # Resolve interpolations in place before the config is consumed.
    OmegaConf.resolve(cfg)
    model = cls(cfg)
    load_model(model, weight_path)
    return model
101
+
102
@property
def device(self):
    """Device of the module's first parameter."""
    first_parameter = next(self.parameters())
    return first_parameter.device
105
+
106
def configure(self):
    """Instantiate all sub-modules and helpers described by ``self.cfg``."""
    # Creation order is kept as-is because it fixes the registration order
    # of the sub-modules.
    component_names = (
        "image_tokenizer",
        "tokenizer",
        "camera_embedder",
        "backbone",
        "post_processor",
        "decoder",
        "image_estimator",
        "global_estimator",
    )
    for component in component_names:
        component_cls = find_class(getattr(self.cfg, f"{component}_cls"))
        setattr(self, component, component_cls(getattr(self.cfg, component)))

    # Axis-aligned bounding box [-radius, radius]^3 stored as a buffer.
    radius = self.cfg.radius
    self.bbox: Float[Tensor, "2 3"]
    self.register_buffer(
        "bbox",
        torch.as_tensor(
            [[-radius, -radius, -radius], [radius, radius, radius]],
            dtype=torch.float32,
        ),
    )

    # Tetrahedral grid file lives in ../load/tets relative to this module.
    resolution = self.cfg.isosurface_resolution
    tets_path = os.path.join(
        os.path.dirname(__file__),
        "..",
        "load",
        "tets",
        f"{resolution}_tets.npz",
    )
    self.isosurface_helper = MarchingTetrahedraHelper(resolution, tets_path)

    self.baker = TextureBaker()
    self.image_processor = ImageProcessor()
150
+
151
def triplane_to_meshes(
    self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"]
) -> list[Mesh]:
    """Extract one mesh per batch element from the given triplanes.

    For each triplane the tetrahedral grid vertices are mapped into the
    bounding box, queried and decoded; the decoded density minus
    ``cfg.isosurface_threshold`` drives the isosurface helper, together with
    the decoded per-vertex offsets.
    """
    helper = self.isosurface_helper
    results = []
    for index in range(triplanes.shape[0]):
        plane_set = triplanes[index]
        query_points = scale_tensor(
            helper.grid_vertices.to(triplanes.device),
            helper.points_range,
            self.bbox,
        )

        features = self.query_triplane(query_points, plane_set)
        decoded = self.decoder(features, include=["vertex_offset", "density"])

        level_set = decoded["density"] - self.cfg.isosurface_threshold
        offsets = decoded["vertex_offset"].squeeze(0)
        flat_offsets = offsets.view(-1, 3) if offsets is not None else None

        mesh: Mesh = helper(level_set.view(-1, 1), flat_offsets)
        # Map the extracted vertices from the helper's range into the bbox.
        mesh.v_pos = scale_tensor(mesh.v_pos, helper.points_range, self.bbox)

        results.append(mesh)

    return results
179
+
180
def query_triplane(
    self,
    positions: Float[Tensor, "*B N 3"],
    triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
) -> Float[Tensor, "*B N F"]:
    """Bilinearly sample triplane features at 3D ``positions``.

    Unbatched inputs get a leading batch dimension added (and the result keeps
    it). Positions are rescaled from ``[-radius, radius]`` to ``[-1, 1]``,
    projected onto the (0, 1), (0, 2) and (1, 2) coordinate pairs, sampled
    from the matching plane, and the three results are concatenated along the
    feature axis.
    """
    if positions.ndim != 3:
        # no batch dimension
        triplanes = triplanes[None, ...]
        positions = positions[None, ...]
    assert triplanes.ndim == 5 and positions.ndim == 3

    radius = self.cfg.radius
    positions = scale_tensor(positions, (-radius, radius), (-1, 1))

    plane_axes = ([0, 1], [0, 2], [1, 2])
    indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
        tuple(positions[..., axes] for axes in plane_axes),
        dim=-3,
    ).to(triplanes.dtype)

    # Fold the plane axis into the batch so one grid_sample call covers all.
    flat_planes = rearrange(
        triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3
    ).float()
    flat_grid = rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3).float()
    sampled: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
        flat_planes,
        flat_grid,
        align_corners=True,
        mode="bilinear",
    )
    return rearrange(sampled, "(B Np) Cp () N -> B N (Np Cp)", Np=3)
209
+
210
def get_scene_codes(self, batch) -> Tuple[Float[Tensor, "B 3 C H W"], Any]:
    """Encode the conditioning batch into triplane scene codes.

    Args:
        batch: Dict with ``rgb_cond``, ``mask_cond``, ``c2w_cond``,
            ``intrinsic_cond`` and ``intrinsic_normed_cond``. When
            ``rgb_cond`` is 4-D, a view dimension is inserted at index 1 of
            all five entries. ``batch`` is modified in place.

    Returns:
        ``(scene_codes, direct_codes)``: the post-processed codes and the raw
        detokenized codes they were derived from. The annotation previously
        declared only a single tensor.
    """
    # if batch[rgb_cond] is only one view, add a view dimension
    if len(batch["rgb_cond"].shape) == 4:
        for key in (
            "rgb_cond",
            "mask_cond",
            "c2w_cond",
            "intrinsic_cond",
            "intrinsic_normed_cond",
        ):
            batch[key] = batch[key].unsqueeze(1)

    batch_size, n_input_views = batch["rgb_cond"].shape[:2]

    camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
    camera_embeds = self.camera_embedder(**batch)

    # The image tokenizer takes channel-first views, modulated by the camera.
    input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.image_tokenizer(
        rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"),
        modulation_cond=camera_embeds,
    )

    # Merge the tokens of all views into one sequence per batch element.
    input_image_tokens = rearrange(
        input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views
    )

    tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)

    tokens = self.backbone(
        tokens,
        encoder_hidden_states=input_image_tokens,
        modulation_cond=None,
    )

    direct_codes = self.tokenizer.detokenize(tokens)
    scene_codes = self.post_processor(direct_codes)
    return scene_codes, direct_codes
244
+
245
def run_image(
    self,
    image: Union[Image.Image, List[Image.Image]],
    bake_resolution: int,
    remesh: Literal["none", "triangle", "quad"] = "none",
    vertex_count: int = -1,
    estimate_illumination: bool = False,
) -> Tuple[Union[trimesh.Trimesh, List[trimesh.Trimesh]], dict[str, Any]]:
    """Reconstruct textured mesh(es) from one RGBA image or a list of them.

    The default conditioning camera (``cfg.default_distance``,
    ``cfg.default_fovy_deg``) is used for every input.

    Returns:
        ``(mesh, global_dict)`` when the batch size is 1 (a single image or a
        one-element list), otherwise ``(meshes, global_dict)``.
    """
    if isinstance(image, list):
        prepared = [self.prepare_image(single) for single in image]
        rgb_cond = torch.stack([rgb for _, rgb in prepared], 0)
        mask_cond = torch.stack([mask for mask, _ in prepared], 0)
        batch_size = rgb_cond.shape[0]
    else:
        mask_cond, rgb_cond = self.prepare_image(image)
        batch_size = 1

    c2w_cond = default_cond_c2w(self.cfg.default_distance).to(self.device)
    side = self.cfg.cond_image_size
    intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_deg(
        self.cfg.default_fovy_deg,
        side,
        side,
    )

    def tile_per_sample(matrix, dim):
        # Give the shared camera matrix batch and view axes, one copy per sample.
        return matrix.view(1, 1, dim, dim).repeat(batch_size, 1, 1, 1)

    batch = {
        "rgb_cond": rgb_cond,
        "mask_cond": mask_cond,
        "c2w_cond": tile_per_sample(c2w_cond, 4),
        "intrinsic_cond": tile_per_sample(intrinsic.to(self.device), 3),
        "intrinsic_normed_cond": tile_per_sample(
            intrinsic_normed_cond.to(self.device), 3
        ),
    }

    meshes, global_dict = self.generate_mesh(
        batch, bake_resolution, remesh, vertex_count, estimate_illumination
    )
    result = meshes[0] if batch_size == 1 else meshes
    return result, global_dict
293
+
294
def prepare_image(self, image):
    """Turn an RGBA image into ``(mask, rgb)`` conditioning tensors.

    The image is resized to ``cfg.cond_image_size`` squared and scaled to
    ``[0, 1]``. The alpha channel becomes the mask, and the colour channels
    are blended over ``cfg.background_color`` using that mask.

    Raises:
        ValueError: If ``image.mode`` is not ``"RGBA"``.
    """
    if image.mode != "RGBA":
        raise ValueError("Image must be in RGBA mode")

    side = self.cfg.cond_image_size
    pixels = np.asarray(image.resize((side, side))).astype(np.float32) / 255.0
    img_cond = torch.from_numpy(pixels).float().clip(0, 1).to(self.device)

    # Last channel is alpha; keep the channel axis so it broadcasts over RGB.
    mask_cond = img_cond[:, :, -1:]
    background = torch.tensor(self.cfg.background_color, device=self.device)[
        None, None, :
    ]
    rgb_cond = torch.lerp(background, img_cond[:, :, :3], mask_cond)

    return mask_cond, rgb_cond
316
+
317
def generate_mesh(
    self,
    batch,
    bake_resolution: int,
    remesh: Literal["none", "triangle", "quad"] = "none",
    vertex_count: int = -1,
    estimate_illumination: bool = False,
) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
    """Generate textured meshes for every element of a conditioning batch.

    Steps: resize the conditioning image and mask, compute the scene codes,
    run the image/global estimators, extract a mesh per batch element, and
    for each mesh optionally remesh, unwrap UVs, bake material maps and
    assemble a ``trimesh.Trimesh`` with a PBR material.

    Args:
        batch: Conditioning dict (see ``get_scene_codes``); ``rgb_cond`` and
            ``mask_cond`` are replaced in place by their resized versions.
        bake_resolution: Side length of the baked textures.
        remesh: ``"triangle"`` or ``"quad"`` to remesh; anything else skips it.
        vertex_count: Target vertex count forwarded to the remesher.
        estimate_illumination: Also run ``global_estimator`` when True.

    Returns:
        ``(meshes, global_dict)``. An empty ``trimesh.Trimesh`` stands in for
        any batch element whose extracted mesh has no vertices.
    """
    batch["rgb_cond"] = self.image_processor(
        batch["rgb_cond"], self.cfg.cond_image_size
    )
    batch["mask_cond"] = self.image_processor(
        batch["mask_cond"], self.cfg.cond_image_size
    )
    scene_codes, non_postprocessed_codes = self.get_scene_codes(batch)

    # Estimator outputs; keys prefixed "decoder_" override decoder outputs below.
    global_dict = {}
    if self.image_estimator is not None:
        global_dict.update(
            self.image_estimator(batch["rgb_cond"] * batch["mask_cond"])
        )
    if self.global_estimator is not None and estimate_illumination:
        global_dict.update(self.global_estimator(non_postprocessed_codes))

    device = get_device()
    with torch.no_grad():
        # Autocast is explicitly disabled on CUDA for extraction and baking.
        with (
            torch.autocast(device_type=device, enabled=False)
            if "cuda" in device
            else nullcontext()
        ):
            meshes = self.triplane_to_meshes(scene_codes)

            rets = []
            for i, mesh in enumerate(meshes):
                # Check for empty mesh
                if mesh.v_pos.shape[0] == 0:
                    rets.append(trimesh.Trimesh())
                    continue

                if remesh == "triangle":
                    mesh = mesh.triangle_remesh(triangle_vertex_count=vertex_count)
                elif remesh == "quad":
                    mesh = mesh.quad_remesh(quad_vertex_count=vertex_count)
                else:
                    if vertex_count > 0:
                        print(
                            "Warning: vertex_count is ignored when remesh is none"
                        )

                print("After Remesh", mesh.v_pos.shape[0], mesh.t_pos_idx.shape[0])
                mesh.unwrap_uv()

                # Build textures
                rast = self.baker.rasterize(
                    mesh.v_tex, mesh.t_pos_idx, bake_resolution
                )
                bake_mask = self.baker.get_mask(rast)

                # 3D position of every covered texel, used to query the triplane.
                pos_bake = self.baker.interpolate(
                    mesh.v_pos,
                    rast,
                    mesh.t_pos_idx,
                )
                gb_pos = pos_bake[bake_mask]

                tri_query = self.query_triplane(gb_pos, scene_codes[i])[0]
                decoded = self.decoder(
                    tri_query, exclude=["density", "vertex_offset"]
                )

                nrm = self.baker.interpolate(
                    mesh.v_nrm,
                    rast,
                    mesh.t_pos_idx,
                )
                gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
                decoded["normal"] = gb_nrm

                # Copy global_dict entries whose key starts with "decoder_"
                # into decoded (prefix stripped), taking this batch element.
                for k, v in global_dict.items():
                    if k.startswith("decoder_"):
                        decoded[k.replace("decoder_", "")] = v[i]

                mat_out = {
                    "albedo": decoded["features"],
                    "roughness": decoded["roughness"],
                    "metallic": decoded["metallic"],
                    "normal": normalize(decoded["perturb_normal"]),
                    "bump": None,
                }

                # "bump" already exists as a key, so assigning it inside the
                # loop does not change the dict size during iteration.
                for k, v in mat_out.items():
                    if v is None:
                        continue
                    if v.shape[0] == 1:
                        # Skip and directly add a single value
                        mat_out[k] = v[0]
                    else:
                        f = torch.zeros(
                            bake_resolution,
                            bake_resolution,
                            v.shape[-1],
                            dtype=v.dtype,
                            device=v.device,
                        )
                        if v.shape == f.shape:
                            continue
                        if k == "normal":
                            # Interpolate the un-normalized vertex tangents so
                            # that triangle size has less influence on the
                            # result, then normalize per texel.
                            tng = self.baker.interpolate(
                                mesh.v_tng,
                                rast,
                                mesh.t_pos_idx,
                            )
                            gb_tng = tng[bake_mask]
                            gb_tng = F.normalize(gb_tng, dim=-1)
                            gb_btng = F.normalize(
                                torch.cross(gb_nrm, gb_tng, dim=-1), dim=-1
                            )
                            normal = F.normalize(mat_out["normal"], dim=-1)

                            # Create tangent space matrix and transform normal
                            tangent_matrix = torch.stack(
                                [gb_tng, gb_btng, gb_nrm], dim=-1
                            )
                            normal_tangent = torch.bmm(
                                tangent_matrix.transpose(1, 2), normal.unsqueeze(-1)
                            ).squeeze(-1)

                            # Convert from [-1,1] to [0,1] range for storage
                            normal_tangent = (normal_tangent * 0.5 + 0.5).clamp(
                                0, 1
                            )

                            # The tangent-space result is stored under "bump";
                            # mat_out["normal"] keeps the per-texel vectors.
                            f[bake_mask] = normal_tangent.view(-1, 3)
                            mat_out["bump"] = f
                        else:
                            f[bake_mask] = v.view(-1, v.shape[-1])
                            mat_out[k] = f

                def uv_padding(arr):
                    # Dilate a baked map beyond the covered texels; 1-D inputs
                    # (single values) are returned unchanged.
                    if arr.ndim == 1:
                        return arr
                    return (
                        dilate_fill(
                            arr.permute(2, 0, 1)[None, ...].contiguous(),
                            bake_mask.unsqueeze(0).unsqueeze(0),
                            iterations=bake_resolution // 150,
                        )
                        .squeeze(0)
                        .permute(1, 2, 0)
                        .contiguous()
                    )

                verts_np = convert_data(mesh.v_pos)
                faces = convert_data(mesh.t_pos_idx)
                uvs = convert_data(mesh.v_tex)

                basecolor_tex = Image.fromarray(
                    float32_to_uint8_np(convert_data(uv_padding(mat_out["albedo"])))
                ).convert("RGB")
                basecolor_tex.format = "JPEG"

                # Metallic and roughness are exported as scalar factors.
                metallic = mat_out["metallic"].squeeze().cpu().item()
                roughness = mat_out["roughness"].squeeze().cpu().item()

                if "bump" in mat_out and mat_out["bump"] is not None:
                    bump_np = convert_data(uv_padding(mat_out["bump"]))
                    # Reference "flat" tangent-space normal: (0.5, 0.5, 1).
                    bump_up = np.ones_like(bump_np)
                    bump_up[..., :2] = 0.5
                    bump_up[..., 2:] = 1
                    bump_tex = Image.fromarray(
                        float32_to_uint8_np(
                            bump_np,
                            dither=True,
                            # Do not dither if something is perfectly flat
                            # NOTE(review): this mask is 1 where the texel IS
                            # flat, and float32_to_uint8_np multiplies the
                            # noise by it, which looks inverted relative to
                            # the comment above — confirm intent.
                            dither_mask=np.all(
                                bump_np == bump_up, axis=-1, keepdims=True
                            ).astype(np.float32),
                        )
                    ).convert("RGB")
                    bump_tex.format = (
                        "JPEG"  # PNG would be better but the assets are larger
                    )
                else:
                    bump_tex = None

                material = trimesh.visual.material.PBRMaterial(
                    baseColorTexture=basecolor_tex,
                    roughnessFactor=roughness,
                    metallicFactor=metallic,
                    normalTexture=bump_tex,
                )

                tmesh = trimesh.Trimesh(
                    vertices=verts_np,
                    faces=faces,
                    visual=trimesh.visual.texture.TextureVisuals(
                        uv=uvs, material=material
                    ),
                )
                # Re-orient the mesh: -90 degrees about X, then +90 about Y.
                rot = trimesh.transformations.rotation_matrix(
                    np.radians(-90), [1, 0, 0]
                )
                tmesh.apply_transform(rot)
                tmesh.apply_transform(
                    trimesh.transformations.rotation_matrix(
                        np.radians(90), [0, 1, 0]
                    )
                )

                tmesh.invert()

                rets.append(tmesh)

    return rets, global_dict
sf3d/utils.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Union
3
+
4
+ import numpy as np
5
+ import rembg
6
+ import torch
7
+ import torchvision.transforms.functional as torchvision_F
8
+ from PIL import Image
9
+
10
+ import sf3d.models.utils as sf3d_utils
11
+
12
+
13
def get_device():
    """Return the device string to use: ``"cpu"``, ``"cuda"`` or ``"mps"``.

    Setting the environment variable ``SF3D_USE_CPU`` to ``"1"`` forces CPU.
    Otherwise CUDA is preferred over MPS, with CPU as the fallback.
    """
    if os.environ.get("SF3D_USE_CPU", "0") == "1":
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
23
+
24
+
25
def create_intrinsic_from_fov_deg(fov_deg: float, cond_height: int, cond_width: int):
    """Build pixel-space and normalised intrinsics from a FOV in degrees.

    Returns:
        ``(intrinsic, intrinsic_normed_cond)``. The second is a copy whose
        first-row focal length and principal point are divided by
        ``cond_width`` and whose second-row entries are divided by
        ``cond_height``.
    """
    intrinsic = sf3d_utils.get_intrinsic_from_fov(
        np.deg2rad(fov_deg),
        H=cond_height,
        W=cond_width,
    )

    normed = intrinsic.clone()
    # Row 0 holds the x terms (scaled by width), row 1 the y terms (height).
    for row, extent in ((0, cond_width), (1, cond_height)):
        normed[..., row, 2] /= extent
        normed[..., row, row] /= extent

    return intrinsic, normed
38
+
39
+
40
def default_cond_c2w(distance: float):
    """Return the default 4x4 camera-to-world matrix as a float32 tensor.

    The rotation part is a fixed axis permutation; ``distance`` is stored in
    the first row of the translation column.
    """
    rows = [
        [0, 0, 1, distance],
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32)
50
+
51
+
52
def remove_background(
    image: Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> Image:
    """Run ``rembg`` background removal unless the image already has transparency.

    An RGBA image whose minimum alpha is below 255 is returned unchanged,
    unless ``force`` is True. Extra keyword arguments are forwarded to
    ``rembg.remove``.
    """
    already_cut_out = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if force or not already_cut_out:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image
65
+
66
+
67
def get_1d_bounds(arr):
    """Return the first and last flat indices of the non-zero entries of ``arr``.

    Raises ``IndexError`` when ``arr`` has no non-zero entry.
    """
    hits = np.flatnonzero(arr)
    first, last = hits[0], hits[-1]
    return first, last


def get_bbox_from_mask(mask, thr=0.5):
    """Return the inclusive bounding box ``(x0, y0, x1, y1)`` of ``mask > thr``.

    x bounds come from summing over the second-to-last axis, y bounds from
    summing over the last axis. Asserts that at least one entry exceeds ``thr``.
    """
    foreground = (mask > thr).astype(np.float32)
    assert foreground.sum() > 0, "Empty mask!"
    column_hits = foreground.sum(axis=-2)
    row_hits = foreground.sum(axis=-1)
    x0, x1 = get_1d_bounds(column_hits)
    y0, y1 = get_1d_bounds(row_hits)
    return x0, y0, x1, y1
78
+
79
+
80
def resize_foreground(
    image: Union[Image.Image, np.ndarray],
    ratio: float,
    out_size=None,
) -> Image:
    """Square-crop an RGBA image around its foreground.

    The foreground bounding box is taken from the alpha channel (any alpha
    above 0.5). The crop is a square centred on that box with side
    ``max(h, w) / ratio``, so the foreground's longer side occupies ``ratio``
    of the crop.

    Args:
        image: RGBA PIL image, or an array converted with ``mode="RGBA"``.
        ratio: Fraction of the crop's side covered by the foreground.
        out_size: Optional size passed to ``resize`` on the cropped image.

    Returns:
        The cropped (and optionally resized) image.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image, mode="RGBA")
    assert image.mode == "RGBA"

    # Get bounding box
    alpha = np.array(image)[:, :, -1]
    left, top, right, bottom = get_bbox_from_mask(alpha, thr=0.5)
    box_h, box_w = bottom - top, right - left
    center_y, center_x = (top + bottom) / 2, (left + right) / 2
    side = max(box_h, box_w) / ratio

    cropped = torchvision_F.crop(
        image,
        top=int(center_y - side / 2),
        left=int(center_x - side / 2),
        height=int(side),
        width=int(side),
    )
    if out_size is not None:
        cropped = cropped.resize(out_size)

    return cropped
texture_baker/README.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Texture baker
2
+
3
+ Small texture baker which rasterizes barycentric coordinates to a tensor.
4
+ It also implements an interpolation module that can then be used to bake vertex attributes into textures.
5
+
6
+ ## Usage
7
+
8
+ The baker can quickly bake vertex attributes to a texture atlas based on the UV coordinates.
9
+ It supports baking on the CPU and GPU.
10
+
11
+ ```python
12
+ from texture_baker import TextureBaker
13
+
14
+ mesh = ...
15
+ uv = mesh.uv # num_vertex, 2
16
+ triangle_idx = mesh.faces # num_faces, 3
17
+ vertices = mesh.vertices # num_vertex, 3
18
+
19
+ tb = TextureBaker()
20
+ # First get the barycentric coordinates
21
+ rast = tb.rasterize(
22
+ uv=uv, face_indices=triangle_idx, bake_resolution=1024
23
+ )
24
+ # Then interpolate vertex attributes
25
+ position_bake = tb.interpolate(attr=vertices, rast=rast, face_indices=triangle_idx)
26
+ ```
texture_baker/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ numpy
texture_baker/setup.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import platform
4
+
5
+ import torch
6
+ from setuptools import find_packages, setup
7
+ from torch.utils.cpp_extension import (
8
+ CUDA_HOME,
9
+ BuildExtension,
10
+ CppExtension,
11
+ CUDAExtension,
12
+ )
13
+
14
+ library_name = "texture_baker"
15
+
16
+
17
def get_extensions():
    """Assemble the native extension(s) passed to ``setup(ext_modules=...)``.

    Behaviour is controlled by environment variables:

    * ``DEBUG=1``: unoptimised build with debug symbols.
    * ``USE_CUDA``: defaults to whether CUDA is available; also requires
      ``CUDA_HOME`` to be set.
    * ``USE_METAL``: defaults to whether MPS is available.
    * ``USE_NATIVE_ARCH``: adds ``-march=native`` to the C++ flags (default on).

    Returns:
        A list with the configured extension, or ``None`` when no ``.cpp``
        sources are found under ``<library_name>/csrc``.
    """
    debug_mode = os.getenv("DEBUG", "0") == "1"
    use_cuda = os.getenv("USE_CUDA", "1" if torch.cuda.is_available() else "0") == "1"
    use_metal = (
        os.getenv("USE_METAL", "1" if torch.backends.mps.is_available() else "0") == "1"
    )
    use_native_arch = os.getenv("USE_NATIVE_ARCH", "1") == "1"
    if debug_mode:
        print("Compiling in debug mode")

    use_cuda = use_cuda and CUDA_HOME is not None
    extension = CUDAExtension if use_cuda else CppExtension

    is_hip_extension = (
        os.environ.get("ROCM_HOME") is not None and torch.version.hip is not None
    )

    extra_link_args = []
    cxx_flags = [
        "-O3" if not debug_mode else "-O0",
        "-fdiagnostics-color=always",
        "-fopenmp",
    ]
    # The original ``[...] + ["-march=native"] if use_native_arch else []``
    # parsed as ``([...] + [...]) if ... else []``, so USE_NATIVE_ARCH=0
    # dropped every C++ flag. Only the arch flag is conditional.
    if use_native_arch:
        cxx_flags.append("-march=native")
    extra_compile_args = {
        "cxx": cxx_flags,
        "nvcc": [
            "-O3" if not debug_mode else "-O0",
        ],
    }
    if debug_mode:
        extra_compile_args["cxx"].append("-g")
        if platform.system() == "Windows":
            extra_compile_args["cxx"].append("/Z7")
            extra_compile_args["cxx"].append("/Od")
            extra_link_args.extend(["/DEBUG"])
        extra_compile_args["cxx"].append("-UNDEBUG")
        extra_compile_args["nvcc"].append("-UNDEBUG")
        extra_compile_args["nvcc"].append("-g")
        extra_link_args.extend(["-O0", "-g"])

    define_macros = []
    extensions = []
    libraries = []

    # dirname(".") is "", so source paths are relative to the working directory.
    this_dir = os.path.dirname(os.path.curdir)
    sources = glob.glob(
        os.path.join(this_dir, library_name, "csrc", "**", "*.cpp"), recursive=True
    )

    if len(sources) == 0:
        print("No source files found for extension, skipping extension compilation")
        return None

    if use_cuda:
        define_macros += [
            ("THRUST_IGNORE_CUB_VERSION_CHECK", None),
        ]
        sources += glob.glob(
            os.path.join(this_dir, library_name, "csrc", "**", "*.cu"), recursive=True
        )

        # CUDA runtime libraries are only linked for non-HIP CUDA builds.
        if not is_hip_extension:
            libraries += ["cudart", "c10_cuda"]

    if use_metal:
        define_macros += [
            ("WITH_MPS", None),
        ]
        sources += glob.glob(
            os.path.join(this_dir, library_name, "csrc", "**", "*.mm"), recursive=True
        )
        # NOTE: this replaces the C++ flags wholesale, including any debug
        # flags set above.
        extra_compile_args.update(
            {"cxx": ["-O3", "-arch", "arm64", "-mmacosx-version-min=10.15"]}
        )
        extra_link_args += ["-arch", "arm64"]

    extensions.append(
        extension(
            name=f"{library_name}._C",
            sources=sources,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args,
            libraries=libraries
            + [
                "c10",
                "torch",
                "torch_cpu",
                "torch_python",
            ],
        )
    )

    # Link the static CUDA runtime instead of the shared one.
    for ext in extensions:
        ext.libraries = ["cudart_static" if x == "cudart" else x for x in ext.libraries]

    print(extensions)

    return extensions
125
+
126
+
127
# Read the long description up front so the file handle is closed right away
# and decoding does not depend on the locale's default encoding.
with open("README.md", encoding="utf-8") as readme_file:
    long_description = readme_file.read()

setup(
    name=library_name,
    version="0.0.1",
    packages=find_packages(where="."),
    package_dir={"": "."},
    ext_modules=get_extensions(),
    install_requires=[],
    # Ship the C++ headers and Metal shader sources alongside the package.
    package_data={
        library_name: [os.path.join("csrc", "*.h"), os.path.join("csrc", "*.metal")],
    },
    description="Small texture baker which rasterizes barycentric coordinates to a tensor.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/Stability-AI/texture_baker",
    cmdclass={"build_ext": BuildExtension},
)
texture_baker/texture_baker/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import torch # noqa: F401
2
+
3
+ from . import _C # noqa: F401
4
+ from .baker import TextureBaker # noqa: F401
texture_baker/texture_baker/baker.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch import Tensor
4
+
5
+
6
class TextureBaker(nn.Module):
    """Bake per-vertex mesh attributes into a square UV-space texture.

    The heavy lifting is delegated to the ``texture_baker_cpp`` custom
    operators (``rasterize`` and ``interpolate``) exposed via ``torch.ops``.
    """

    def __init__(self):
        super().__init__()

    def rasterize(
        self,
        uv: Tensor,
        face_indices: Tensor,
        bake_resolution: int,
    ) -> Tensor:
        """
        Rasterize the UV coordinates to a barycentric coordinates
        & triangle index texture map

        Args:
            uv (Tensor, num_vertices 2, float): UV coordinates of the mesh
            face_indices (Tensor, num_faces 3, int): Face indices of the mesh
            bake_resolution (int): Resolution of the bake

        Returns:
            Tensor, bake_resolution bake_resolution 4, float: Rasterized map.
            The first three channels hold the barycentric weights and the last
            channel the triangle index (negative where no triangle was hit).
        """
        # The operator reads the indices as 32-bit integers.
        return torch.ops.texture_baker_cpp.rasterize(
            uv, face_indices.to(torch.int32), bake_resolution
        )

    def get_mask(self, rast: Tensor) -> Tensor:
        """
        Get the occupancy mask from the rasterized map

        Args:
            rast (Tensor, bake_resolution bake_resolution 4, float): Rasterized map

        Returns:
            Tensor, bake_resolution bake_resolution, bool: Mask that is True
            wherever the stored triangle index (last channel) is non-negative
        """
        return rast[..., -1] >= 0

    def interpolate(
        self,
        attr: Tensor,
        rast: Tensor,
        face_indices: Tensor,
    ) -> Tensor:
        """
        Interpolate the attributes using the rasterized map

        Args:
            attr (Tensor, num_vertices 3, float): Attributes of the mesh
            rast (Tensor, bake_resolution bake_resolution 4, float): Rasterized map
            face_indices (Tensor, num_faces 3, int): Face indices of the mesh

        Returns:
            Tensor, bake_resolution bake_resolution 3, float: Interpolated attributes
        """
        # Note the operator's argument order: (attr, indices, rast).
        return torch.ops.texture_baker_cpp.interpolate(
            attr, face_indices.to(torch.int32), rast
        )

    def forward(
        self,
        attr: Tensor,
        uv: Tensor,
        face_indices: Tensor,
        bake_resolution: int,
    ) -> Tensor:
        """
        Bake the texture

        Args:
            attr (Tensor, num_vertices 3, float): Attributes of the mesh
            uv (Tensor, num_vertices 2, float): UV coordinates of the mesh
            face_indices (Tensor, num_faces 3, int): Face indices of the mesh
            bake_resolution (int): Resolution of the bake

        Returns:
            Tensor, bake_resolution bake_resolution 3, float: Baked texture
        """
        rast = self.rasterize(uv, face_indices, bake_resolution)
        # interpolate() takes exactly (attr, rast, face_indices); the UVs are
        # only needed for rasterization.
        return self.interpolate(attr, rast, face_indices)
texture_baker/texture_baker/csrc/baker.cpp ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+ #include <ATen/Context.h>
3
+ #include <chrono>
4
+ #include <cmath>
5
+ #include <omp.h>
6
+ #include <torch/extension.h>
7
+ #ifndef __ARM_ARCH_ISA_A64
8
+ #include <immintrin.h>
9
+ #endif
10
+
11
+ #include "baker.h"
12
+
13
+ // #define TIMING
14
+ #define BINS 8
15
+
16
+ namespace texture_baker_cpp {
17
+ // Calculate the centroid of a triangle
18
+ tb_float2 triangle_centroid(const tb_float2 &v0, const tb_float2 &v1,
19
+ const tb_float2 &v2) {
20
+ return {(v0.x + v1.x + v2.x) * 0.3333f, (v0.y + v1.y + v2.y) * 0.3333f};
21
+ }
22
+
23
+ float BVH::find_best_split_plane(const BVHNode &node, int &best_axis,
24
+ int &best_pos, AABB &centroidBounds) {
25
+ float best_cost = std::numeric_limits<float>::max();
26
+
27
+ for (int axis = 0; axis < 2; ++axis) // We use 2 as we have only x and y
28
+ {
29
+ float boundsMin = centroidBounds.min[axis];
30
+ float boundsMax = centroidBounds.max[axis];
31
+ if (boundsMin == boundsMax) {
32
+ continue;
33
+ }
34
+
35
+ // Populate the bins
36
+ float scale = BINS / (boundsMax - boundsMin);
37
+ float leftCountArea[BINS - 1], rightCountArea[BINS - 1];
38
+ int leftSum = 0, rightSum = 0;
39
+
40
+ #ifndef __ARM_ARCH_ISA_A64
41
+ #ifndef _MSC_VER
42
+ if (__builtin_cpu_supports("sse"))
43
+ #elif (defined(_M_AMD64) || defined(_M_X64))
44
+ // SSE supported on Windows
45
+ if constexpr (true)
46
+ #endif
47
+ {
48
+ __m128 min4[BINS], max4[BINS];
49
+ unsigned int count[BINS];
50
+ for (unsigned int i = 0; i < BINS; i++)
51
+ min4[i] = _mm_set_ps1(1e30f), max4[i] = _mm_set_ps1(-1e30f),
52
+ count[i] = 0;
53
+ for (int i = node.start; i < node.end; i++) {
54
+ int tri_idx = triangle_indices[i];
55
+ const Triangle &triangle = triangles[tri_idx];
56
+
57
+ int binIdx = std::min(
58
+ BINS - 1, (int)((triangle.centroid[axis] - boundsMin) * scale));
59
+ count[binIdx]++;
60
+ __m128 v0 = _mm_set_ps(triangle.v0.x, triangle.v0.y, 0.0f, 0.0f);
61
+ __m128 v1 = _mm_set_ps(triangle.v1.x, triangle.v1.y, 0.0f, 0.0f);
62
+ __m128 v2 = _mm_set_ps(triangle.v2.x, triangle.v2.y, 0.0f, 0.0f);
63
+ min4[binIdx] = _mm_min_ps(min4[binIdx], v0);
64
+ max4[binIdx] = _mm_max_ps(max4[binIdx], v0);
65
+ min4[binIdx] = _mm_min_ps(min4[binIdx], v1);
66
+ max4[binIdx] = _mm_max_ps(max4[binIdx], v1);
67
+ min4[binIdx] = _mm_min_ps(min4[binIdx], v2);
68
+ max4[binIdx] = _mm_max_ps(max4[binIdx], v2);
69
+ }
70
+ // gather data for the 7 planes between the 8 bins
71
+ __m128 leftMin4 = _mm_set_ps1(1e30f), rightMin4 = leftMin4;
72
+ __m128 leftMax4 = _mm_set_ps1(-1e30f), rightMax4 = leftMax4;
73
+ for (int i = 0; i < BINS - 1; i++) {
74
+ leftSum += count[i];
75
+ rightSum += count[BINS - 1 - i];
76
+ leftMin4 = _mm_min_ps(leftMin4, min4[i]);
77
+ rightMin4 = _mm_min_ps(rightMin4, min4[BINS - 2 - i]);
78
+ leftMax4 = _mm_max_ps(leftMax4, max4[i]);
79
+ rightMax4 = _mm_max_ps(rightMax4, max4[BINS - 2 - i]);
80
+ float le[4], re[4];
81
+ _mm_store_ps(le, _mm_sub_ps(leftMax4, leftMin4));
82
+ _mm_store_ps(re, _mm_sub_ps(rightMax4, rightMin4));
83
+ // SSE order goes from back to front
84
+ leftCountArea[i] = leftSum * (le[2] * le[3]); // 2D area calculation
85
+ rightCountArea[BINS - 2 - i] =
86
+ rightSum * (re[2] * re[3]); // 2D area calculation
87
+ }
88
+ }
89
+ #else
90
+ if constexpr (false) {
91
+ }
92
+ #endif
93
+ else {
94
+ struct Bin {
95
+ AABB bounds;
96
+ int triCount = 0;
97
+ } bins[BINS];
98
+
99
+ for (int i = node.start; i < node.end; i++) {
100
+ int tri_idx = triangle_indices[i];
101
+ const Triangle &triangle = triangles[tri_idx];
102
+
103
+ int binIdx = std::min(
104
+ BINS - 1, (int)((triangle.centroid[axis] - boundsMin) * scale));
105
+ bins[binIdx].triCount++;
106
+ bins[binIdx].bounds.grow(triangle.v0);
107
+ bins[binIdx].bounds.grow(triangle.v1);
108
+ bins[binIdx].bounds.grow(triangle.v2);
109
+ }
110
+
111
+ // Gather data for the planes between the bins
112
+ AABB leftBox, rightBox;
113
+
114
+ for (int i = 0; i < BINS - 1; i++) {
115
+ leftSum += bins[i].triCount;
116
+ leftBox.grow(bins[i].bounds);
117
+ leftCountArea[i] = leftSum * leftBox.area();
118
+
119
+ rightSum += bins[BINS - 1 - i].triCount;
120
+ rightBox.grow(bins[BINS - 1 - i].bounds);
121
+ rightCountArea[BINS - 2 - i] = rightSum * rightBox.area();
122
+ }
123
+ }
124
+
125
+ // Calculate SAH cost for the planes
126
+ scale = (boundsMax - boundsMin) / BINS;
127
+ for (int i = 0; i < BINS - 1; i++) {
128
+ float planeCost = leftCountArea[i] + rightCountArea[i];
129
+ if (planeCost < best_cost) {
130
+ best_axis = axis;
131
+ best_pos = i + 1;
132
+ best_cost = planeCost;
133
+ }
134
+ }
135
+ }
136
+
137
+ return best_cost;
138
+ }
139
+
140
+ void BVH::update_node_bounds(BVHNode &node, AABB &centroidBounds) {
141
+ #ifndef __ARM_ARCH_ISA_A64
142
+ #ifndef _MSC_VER
143
+ if (__builtin_cpu_supports("sse"))
144
+ #elif (defined(_M_AMD64) || defined(_M_X64))
145
+ // SSE supported on Windows
146
+ if constexpr (true)
147
+ #endif
148
+ {
149
+ __m128 min4 = _mm_set_ps1(1e30f), max4 = _mm_set_ps1(-1e30f);
150
+ __m128 cmin4 = _mm_set_ps1(1e30f), cmax4 = _mm_set_ps1(-1e30f);
151
+
152
+ for (int i = node.start; i < node.end; i += 2) {
153
+ int tri_idx1 = triangle_indices[i];
154
+ const Triangle &leafTri1 = triangles[tri_idx1];
155
+ // Check if the second actually exists in the node
156
+ __m128 v0, v1, v2, centroid;
157
+ if (i + 1 < node.end) {
158
+ int tri_idx2 = triangle_indices[i + 1];
159
+ const Triangle leafTri2 = triangles[tri_idx2];
160
+
161
+ v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri2.v0.x,
162
+ leafTri2.v0.y);
163
+ v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri2.v1.x,
164
+ leafTri2.v1.y);
165
+ v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri2.v2.x,
166
+ leafTri2.v2.y);
167
+ centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
168
+ leafTri2.centroid.x, leafTri2.centroid.y);
169
+ } else {
170
+ // Otherwise do some duplicated work
171
+ v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri1.v0.x,
172
+ leafTri1.v0.y);
173
+ v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri1.v1.x,
174
+ leafTri1.v1.y);
175
+ v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri1.v2.x,
176
+ leafTri1.v2.y);
177
+ centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
178
+ leafTri1.centroid.x, leafTri1.centroid.y);
179
+ }
180
+
181
+ min4 = _mm_min_ps(min4, v0);
182
+ max4 = _mm_max_ps(max4, v0);
183
+ min4 = _mm_min_ps(min4, v1);
184
+ max4 = _mm_max_ps(max4, v1);
185
+ min4 = _mm_min_ps(min4, v2);
186
+ max4 = _mm_max_ps(max4, v2);
187
+ cmin4 = _mm_min_ps(cmin4, centroid);
188
+ cmax4 = _mm_max_ps(cmax4, centroid);
189
+ }
190
+
191
+ float min_values[4], max_values[4], cmin_values[4], cmax_values[4];
192
+ _mm_store_ps(min_values, min4);
193
+ _mm_store_ps(max_values, max4);
194
+ _mm_store_ps(cmin_values, cmin4);
195
+ _mm_store_ps(cmax_values, cmax4);
196
+
197
+ node.bbox.min.x = std::min(min_values[3], min_values[1]);
198
+ node.bbox.min.y = std::min(min_values[2], min_values[0]);
199
+ node.bbox.max.x = std::max(max_values[3], max_values[1]);
200
+ node.bbox.max.y = std::max(max_values[2], max_values[0]);
201
+
202
+ centroidBounds.min.x = std::min(cmin_values[3], cmin_values[1]);
203
+ centroidBounds.min.y = std::min(cmin_values[2], cmin_values[0]);
204
+ centroidBounds.max.x = std::max(cmax_values[3], cmax_values[1]);
205
+ centroidBounds.max.y = std::max(cmax_values[2], cmax_values[0]);
206
+ }
207
+ #else
208
+ if constexpr (false) {
209
+ }
210
+ #endif
211
+ {
212
+ node.bbox.invalidate();
213
+ centroidBounds.invalidate();
214
+
215
+ // Calculate the bounding box for the node
216
+ for (int i = node.start; i < node.end; ++i) {
217
+ int tri_idx = triangle_indices[i];
218
+ const Triangle &tri = triangles[tri_idx];
219
+ node.bbox.grow(tri.v0);
220
+ node.bbox.grow(tri.v1);
221
+ node.bbox.grow(tri.v2);
222
+ centroidBounds.grow(tri.centroid);
223
+ }
224
+ }
225
+ }
226
+
227
+ void BVH::build(const tb_float2 *vertices, const tb_int3 *indices,
228
+ const int64_t &num_indices) {
229
+ #ifdef TIMING
230
+ auto start = std::chrono::high_resolution_clock::now();
231
+ #endif
232
+ // Create triangles
233
+ for (size_t i = 0; i < num_indices; ++i) {
234
+ tb_int3 idx = indices[i];
235
+ triangles.push_back(
236
+ {vertices[idx.x], vertices[idx.y], vertices[idx.z], static_cast<int>(i),
237
+ triangle_centroid(vertices[idx.x], vertices[idx.y], vertices[idx.z])});
238
+ }
239
+
240
+ // Initialize triangle_indices
241
+ triangle_indices.resize(triangles.size());
242
+ std::iota(triangle_indices.begin(), triangle_indices.end(), 0);
243
+
244
+ // Build BVH nodes
245
+ // Reserve extra capacity to fix windows specific crashes
246
+ nodes.reserve(triangles.size() * 2 + 1);
247
+ nodes.push_back({}); // Create the root node
248
+ root = 0;
249
+
250
+ // Define a struct for queue entries
251
+ struct QueueEntry {
252
+ int node_idx;
253
+ int start;
254
+ int end;
255
+ };
256
+
257
+ // Queue for breadth-first traversal
258
+ std::queue<QueueEntry> node_queue;
259
+ node_queue.push({root, 0, (int)triangles.size()});
260
+
261
+ // Process each node in the queue
262
+ while (!node_queue.empty()) {
263
+ QueueEntry current = node_queue.front();
264
+ node_queue.pop();
265
+
266
+ int node_idx = current.node_idx;
267
+ int start = current.start;
268
+ int end = current.end;
269
+
270
+ BVHNode &node = nodes[node_idx];
271
+ node.start = start;
272
+ node.end = end;
273
+
274
+ // Calculate the bounding box for the node
275
+ AABB centroidBounds;
276
+ update_node_bounds(node, centroidBounds);
277
+
278
+ // Determine the best split using SAH
279
+ int best_axis, best_pos;
280
+
281
+ float splitCost =
282
+ find_best_split_plane(node, best_axis, best_pos, centroidBounds);
283
+ float nosplitCost = node.calculate_node_cost();
284
+
285
+ // Stop condition: if the best cost is greater than or equal to the parent's
286
+ // cost
287
+ if (splitCost >= nosplitCost) {
288
+ // Leaf node
289
+ node.left = node.right = -1;
290
+ continue;
291
+ }
292
+
293
+ float scale =
294
+ BINS / (centroidBounds.max[best_axis] - centroidBounds.min[best_axis]);
295
+ int i = node.start;
296
+ int j = node.end - 1;
297
+
298
+ // Sort the triangle_indices in the range [start, end) based on the best
299
+ // axis
300
+ while (i <= j) {
301
+ // use the exact calculation we used for binning to prevent rare
302
+ // inaccuracies
303
+ int tri_idx = triangle_indices[i];
304
+ tb_float2 tcentr = triangles[tri_idx].centroid;
305
+ int binIdx = std::min(
306
+ BINS - 1,
307
+ (int)((tcentr[best_axis] - centroidBounds.min[best_axis]) * scale));
308
+ if (binIdx < best_pos)
309
+ i++;
310
+ else
311
+ std::swap(triangle_indices[i], triangle_indices[j--]);
312
+ }
313
+ int leftCount = i - node.start;
314
+ if (leftCount == 0 || leftCount == node.num_triangles()) {
315
+ // Leaf node
316
+ node.left = node.right = -1;
317
+ continue;
318
+ }
319
+
320
+ int mid = i;
321
+
322
+ // Create and set left child
323
+ node.left = nodes.size();
324
+ nodes.push_back({});
325
+ node_queue.push({node.left, start, mid});
326
+
327
+ // Create and set right child
328
+ node = nodes[node_idx]; // Update the node - Potentially stale reference
329
+ node.right = nodes.size();
330
+ nodes.push_back({});
331
+ node_queue.push({node.right, mid, end});
332
+ }
333
+ #ifdef TIMING
334
+ auto end = std::chrono::high_resolution_clock::now();
335
+ std::chrono::duration<double> elapsed = end - start;
336
+ std::cout << "BVH build time: " << elapsed.count() << "s" << std::endl;
337
+ #endif
338
+ }
339
+
340
// Restrict val to the range [minVal, maxVal]: first raise it to the lower
// limit, then cap the result at the upper limit
float clamp(float val, float minVal, float maxVal) {
  const float raised = (val < minVal) ? minVal : val;
  return (maxVal < raised) ? maxVal : raised;
}
344
+
345
+ // Function to check if a point (xy) is inside a triangle defined by vertices
346
+ // v1, v2, v3
347
+ bool barycentric_coordinates(tb_float2 xy, tb_float2 v1, tb_float2 v2,
348
+ tb_float2 v3, float &u, float &v, float &w) {
349
+ // Vectors from v1 to v2, v3 and xy
350
+ tb_float2 v1v2 = {v2.x - v1.x, v2.y - v1.y};
351
+ tb_float2 v1v3 = {v3.x - v1.x, v3.y - v1.y};
352
+ tb_float2 xyv1 = {xy.x - v1.x, xy.y - v1.y};
353
+
354
+ // Dot products of the vectors
355
+ float d00 = v1v2.x * v1v2.x + v1v2.y * v1v2.y;
356
+ float d01 = v1v2.x * v1v3.x + v1v2.y * v1v3.y;
357
+ float d11 = v1v3.x * v1v3.x + v1v3.y * v1v3.y;
358
+ float d20 = xyv1.x * v1v2.x + xyv1.y * v1v2.y;
359
+ float d21 = xyv1.x * v1v3.x + xyv1.y * v1v3.y;
360
+
361
+ // Calculate the barycentric coordinates
362
+ float denom = d00 * d11 - d01 * d01;
363
+ v = (d11 * d20 - d01 * d21) / denom;
364
+ w = (d00 * d21 - d01 * d20) / denom;
365
+ u = 1.0f - v - w;
366
+
367
+ // Check if the point is inside the triangle
368
+ return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f);
369
+ }
370
+
371
+ bool BVH::intersect(const tb_float2 &point, float &u, float &v, float &w,
372
+ int &index) const {
373
+ const int max_stack_size = 64;
374
+ int node_stack[max_stack_size];
375
+ int stack_size = 0;
376
+
377
+ node_stack[stack_size++] = root;
378
+
379
+ while (stack_size > 0) {
380
+ int node_idx = node_stack[--stack_size];
381
+ const BVHNode &node = nodes[node_idx];
382
+
383
+ if (node.is_leaf()) {
384
+ for (int i = node.start; i < node.end; ++i) {
385
+ const Triangle &tri = triangles[triangle_indices[i]];
386
+ if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w)) {
387
+ index = tri.index;
388
+ return true;
389
+ }
390
+ }
391
+ } else {
392
+ if (nodes[node.right].bbox.overlaps(point)) {
393
+ if (stack_size < max_stack_size) {
394
+ node_stack[stack_size++] = node.right;
395
+ } else {
396
+ // Handle stack overflow
397
+ throw std::runtime_error("Node stack overflow");
398
+ }
399
+ }
400
+ if (nodes[node.left].bbox.overlaps(point)) {
401
+ if (stack_size < max_stack_size) {
402
+ node_stack[stack_size++] = node.left;
403
+ } else {
404
+ // Handle stack overflow
405
+ throw std::runtime_error("Node stack overflow");
406
+ }
407
+ }
408
+ }
409
+ }
410
+
411
+ return false;
412
+ }
413
+
414
+ torch::Tensor rasterize_cpu(torch::Tensor uv, torch::Tensor indices,
415
+ int64_t bake_resolution) {
416
+ int width = bake_resolution;
417
+ int height = bake_resolution;
418
+ int num_pixels = width * height;
419
+ torch::Tensor rast_result = torch::empty(
420
+ {bake_resolution, bake_resolution, 4},
421
+ torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU));
422
+
423
+ float *rast_result_ptr = rast_result.contiguous().data_ptr<float>();
424
+ const tb_float2 *vertices = (tb_float2 *)uv.data_ptr<float>();
425
+ const tb_int3 *tris = (tb_int3 *)indices.data_ptr<int>();
426
+
427
+ BVH bvh;
428
+ bvh.build(vertices, tris, indices.size(0));
429
+
430
+ #ifdef TIMING
431
+ auto start = std::chrono::high_resolution_clock::now();
432
+ #endif
433
+
434
+ #pragma omp parallel for
435
+ for (int idx = 0; idx < num_pixels; ++idx) {
436
+ int x = idx / height;
437
+ int y = idx % height;
438
+ int idx_ = idx * 4; // Note: *4 because we're storing float4 per pixel
439
+
440
+ tb_float2 pixel_coord = {float(y) / height, float(x) / width};
441
+ pixel_coord.x = clamp(pixel_coord.x, 0.0f, 1.0f);
442
+ pixel_coord.y = 1.0f - clamp(pixel_coord.y, 0.0f, 1.0f);
443
+
444
+ float u, v, w;
445
+ int triangle_idx;
446
+ if (bvh.intersect(pixel_coord, u, v, w, triangle_idx)) {
447
+ rast_result_ptr[idx_ + 0] = u;
448
+ rast_result_ptr[idx_ + 1] = v;
449
+ rast_result_ptr[idx_ + 2] = w;
450
+ rast_result_ptr[idx_ + 3] = static_cast<float>(triangle_idx);
451
+ } else {
452
+ rast_result_ptr[idx_ + 0] = 0.0f;
453
+ rast_result_ptr[idx_ + 1] = 0.0f;
454
+ rast_result_ptr[idx_ + 2] = 0.0f;
455
+ rast_result_ptr[idx_ + 3] = -1.0f;
456
+ }
457
+ }
458
+
459
+ #ifdef TIMING
460
+ auto end = std::chrono::high_resolution_clock::now();
461
+ std::chrono::duration<double> elapsed = end - start;
462
+ std::cout << "Rasterization time: " << elapsed.count() << "s" << std::endl;
463
+ #endif
464
+ return rast_result;
465
+ }
466
+
467
+ torch::Tensor interpolate_cpu(torch::Tensor attr, torch::Tensor indices,
468
+ torch::Tensor rast) {
469
+ #ifdef TIMING
470
+ auto start = std::chrono::high_resolution_clock::now();
471
+ #endif
472
+ int height = rast.size(0);
473
+ int width = rast.size(1);
474
+ torch::Tensor pos_bake = torch::empty(
475
+ {height, width, 3},
476
+ torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU));
477
+
478
+ const float *attr_ptr = attr.contiguous().data_ptr<float>();
479
+ const int *indices_ptr = indices.contiguous().data_ptr<int>();
480
+ const float *rast_ptr = rast.contiguous().data_ptr<float>();
481
+ float *output_ptr = pos_bake.contiguous().data_ptr<float>();
482
+
483
+ int num_pixels = width * height;
484
+
485
+ #pragma omp parallel for
486
+ for (int idx = 0; idx < num_pixels; ++idx) {
487
+ int idx_ = idx * 4; // Index into the float4 array (4 floats per pixel)
488
+ tb_float3 barycentric = {
489
+ rast_ptr[idx_ + 0],
490
+ rast_ptr[idx_ + 1],
491
+ rast_ptr[idx_ + 2],
492
+ };
493
+ int triangle_idx = static_cast<int>(rast_ptr[idx_ + 3]);
494
+
495
+ if (triangle_idx < 0) {
496
+ output_ptr[idx * 3 + 0] = 0.0f;
497
+ output_ptr[idx * 3 + 1] = 0.0f;
498
+ output_ptr[idx * 3 + 2] = 0.0f;
499
+ continue;
500
+ }
501
+
502
+ tb_int3 triangle = {indices_ptr[3 * triangle_idx + 0],
503
+ indices_ptr[3 * triangle_idx + 1],
504
+ indices_ptr[3 * triangle_idx + 2]};
505
+ tb_float3 v1 = {attr_ptr[3 * triangle.x + 0], attr_ptr[3 * triangle.x + 1],
506
+ attr_ptr[3 * triangle.x + 2]};
507
+ tb_float3 v2 = {attr_ptr[3 * triangle.y + 0], attr_ptr[3 * triangle.y + 1],
508
+ attr_ptr[3 * triangle.y + 2]};
509
+ tb_float3 v3 = {attr_ptr[3 * triangle.z + 0], attr_ptr[3 * triangle.z + 1],
510
+ attr_ptr[3 * triangle.z + 2]};
511
+
512
+ tb_float3 interpolated;
513
+ interpolated.x =
514
+ v1.x * barycentric.x + v2.x * barycentric.y + v3.x * barycentric.z;
515
+ interpolated.y =
516
+ v1.y * barycentric.x + v2.y * barycentric.y + v3.y * barycentric.z;
517
+ interpolated.z =
518
+ v1.z * barycentric.x + v2.z * barycentric.y + v3.z * barycentric.z;
519
+
520
+ output_ptr[idx * 3 + 0] = interpolated.x;
521
+ output_ptr[idx * 3 + 1] = interpolated.y;
522
+ output_ptr[idx * 3 + 2] = interpolated.z;
523
+ }
524
+
525
+ #ifdef TIMING
526
+ auto end = std::chrono::high_resolution_clock::now();
527
+ std::chrono::duration<double> elapsed = end - start;
528
+ std::cout << "Interpolation time: " << elapsed.count() << "s" << std::endl;
529
+ #endif
530
+ return pos_bake;
531
+ }
532
+
533
+ // Registers _C as a Python extension module.
534
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
535
+
536
+ // Defines the operators
537
+ TORCH_LIBRARY(texture_baker_cpp, m) {
538
+ m.def("rasterize(Tensor uv, Tensor indices, int bake_resolution) -> Tensor");
539
+ m.def("interpolate(Tensor attr, Tensor indices, Tensor rast) -> Tensor");
540
+ }
541
+
542
+ // Registers CPP implementations
543
+ TORCH_LIBRARY_IMPL(texture_baker_cpp, CPU, m) {
544
+ m.impl("rasterize", &rasterize_cpu);
545
+ m.impl("interpolate", &interpolate_cpu);
546
+ }
547
+
548
+ } // namespace texture_baker_cpp
texture_baker/texture_baker/csrc/baker.h ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #if defined(__NVCC__) || defined(__HIPCC__) || defined(__METAL__)
4
+ #define CUDA_ENABLED
5
+ #ifndef __METAL__
6
+ #define CUDA_HOST_DEVICE __host__ __device__
7
+ #define CUDA_DEVICE __device__
8
+ #define METAL_CONSTANT_MEM
9
+ #define METAL_THREAD_MEM
10
+ #else
11
+ #define tb_float2 float2
12
+ #define CUDA_HOST_DEVICE
13
+ #define CUDA_DEVICE
14
+ #define METAL_CONSTANT_MEM constant
15
+ #define METAL_THREAD_MEM thread
16
+ #endif
17
+ #else
18
+ #define CUDA_HOST_DEVICE
19
+ #define CUDA_DEVICE
20
+ #define METAL_CONSTANT_MEM
21
+ #define METAL_THREAD_MEM
22
+ #include <cfloat>
23
+ #include <limits>
24
+ #include <vector>
25
+ #endif
26
+
27
+ namespace texture_baker_cpp {
28
+ // Structure to represent a 2D point or vector
29
+ #ifndef __METAL__
30
// Two-component float vector; the components can be reached by name (x, y)
// or by position through data[] / operator[]
union alignas(8) tb_float2 {
  struct {
    float x;
    float y;
  };

  float data[2];

  // Checked component access, throws std::runtime_error for idx >= 2
  float &operator[](size_t idx) {
    if (idx >= 2) {
      throw std::runtime_error("bad index");
    }
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx >= 2) {
      throw std::runtime_error("bad index");
    }
    return data[idx];
  }

  // Exact component-wise comparison
  bool operator==(const tb_float2 &rhs) const {
    return !(x != rhs.x || y != rhs.y);
  }
};
53
+
54
// Three-component float vector; the components can be reached by name
// (x, y, z) or by position through data[] / operator[]
union alignas(4) tb_float3 {
  struct {
    float x;
    float y;
    float z;
  };

  float data[3];

  // Checked component access, throws std::runtime_error for idx >= 3
  float &operator[](size_t idx) {
    if (idx >= 3) {
      throw std::runtime_error("bad index");
    }
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx >= 3) {
      throw std::runtime_error("bad index");
    }
    return data[idx];
  }
};
73
+
74
// Four-component float vector; the components can be reached by name
// (x, y, z, w) or by position through data[] / operator[]
union alignas(16) tb_float4 {
  struct {
    float x;
    float y;
    float z;
    float w;
  };

  float data[4];

  // Checked component access, throws std::runtime_error for idx >= 4
  float &operator[](size_t idx) {
    if (idx >= 4) {
      throw std::runtime_error("bad index");
    }
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx >= 4) {
      throw std::runtime_error("bad index");
    }
    return data[idx];
  }
};
93
+ #endif
94
+
95
// Three-component int vector (e.g. the vertex indices of one triangle); the
// components can be reached by name (x, y, z) or by position through data[]
union alignas(4) tb_int3 {
  struct {
    int x, y, z;
  };

  int data[3];
#ifndef __METAL__
  // Checked component access, throws std::runtime_error for idx > 2
  int &operator[](size_t idx) {
    if (idx > 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  // Read-only overload so const values can be indexed too, matching the
  // float vector unions
  const int &operator[](size_t idx) const {
    if (idx > 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }
#endif
};
109
+
110
+ // BVH structure to accelerate point-triangle intersection
111
+ struct alignas(16) AABB {
112
+ // Init bounding boxes with max/min
113
+ tb_float2 min = {FLT_MAX, FLT_MAX};
114
+ tb_float2 max = {FLT_MIN, FLT_MIN};
115
+
116
+ #ifndef CUDA_ENABLED
117
+ // grow the AABB to include a point
118
+ void grow(const tb_float2 &p) {
119
+ min.x = std::min(min.x, p.x);
120
+ min.y = std::min(min.y, p.y);
121
+ max.x = std::max(max.x, p.x);
122
+ max.y = std::max(max.y, p.y);
123
+ }
124
+
125
+ void grow(const AABB &b) {
126
+ if (b.min.x != FLT_MAX) {
127
+ grow(b.min);
128
+ grow(b.max);
129
+ }
130
+ }
131
+ #endif
132
+
133
+ // Check if two AABBs overlap
134
+ bool overlaps(const METAL_THREAD_MEM AABB &other) const {
135
+ return min.x <= other.max.x && max.x >= other.min.x &&
136
+ min.y <= other.max.y && max.y >= other.min.y;
137
+ }
138
+
139
+ bool overlaps(const METAL_THREAD_MEM tb_float2 &point) const {
140
+ return point.x >= min.x && point.x <= max.x && point.y >= min.y &&
141
+ point.y <= max.y;
142
+ }
143
+
144
+ #if defined(__NVCC__) || defined(__HIPCC__)
145
+ CUDA_DEVICE bool overlaps(const float2 &point) const {
146
+ return point.x >= min.x && point.x <= max.x && point.y >= min.y &&
147
+ point.y <= max.y;
148
+ }
149
+ #endif
150
+
151
+ // Initialize AABB to an invalid state
152
+ void invalidate() {
153
+ min = {FLT_MAX, FLT_MAX};
154
+ max = {FLT_MIN, FLT_MIN};
155
+ }
156
+
157
+ // Calculate the area of the AABB
158
+ float area() const {
159
+ tb_float2 extent = {max.x - min.x, max.y - min.y};
160
+ return extent.x * extent.y;
161
+ }
162
+ };
163
+
164
+ struct BVHNode {
165
+ AABB bbox;
166
+ int start, end;
167
+ int left, right;
168
+
169
+ int num_triangles() const { return end - start; }
170
+
171
+ CUDA_HOST_DEVICE bool is_leaf() const { return left == -1 && right == -1; }
172
+
173
+ float calculate_node_cost() {
174
+ float area = bbox.area();
175
+ return num_triangles() * area;
176
+ }
177
+ };
178
+
179
+ struct Triangle {
180
+ tb_float2 v0, v1, v2;
181
+ int index;
182
+ tb_float2 centroid;
183
+ };
184
+
185
+ #ifndef __METAL__
186
+ struct BVH {
187
+ std::vector<BVHNode> nodes;
188
+ std::vector<Triangle> triangles;
189
+ std::vector<int> triangle_indices;
190
+ int root;
191
+
192
+ void build(const tb_float2 *vertices, const tb_int3 *indices,
193
+ const int64_t &num_indices);
194
+ bool intersect(const tb_float2 &point, float &u, float &v, float &w,
195
+ int &index) const;
196
+
197
+ void update_node_bounds(BVHNode &node, AABB &centroidBounds);
198
+ float find_best_split_plane(const BVHNode &node, int &best_axis,
199
+ int &best_pos, AABB &centroidBounds);
200
+ };
201
+ #endif
202
+
203
+ } // namespace texture_baker_cpp
texture_baker/texture_baker/csrc/baker_kernel.cu ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+ #include <ATen/Context.h>
3
+ #include <ATen/cuda/CUDAContext.h>
4
+ #include <torch/extension.h>
5
+
6
+ #include "baker.h"
7
+
8
+ // #define TIMING
9
+
10
// Error-checking helpers for CUDA runtime calls.
// STRINGIFY/STR turn a macro argument (after expansion) into a string
// literal; FILE_LINE uses them to build a "file:line" literal at the point of
// use. CUDA_CHECK_THROW(x) evaluates the CUDA call x once and throws a
// std::runtime_error containing the location, the stringified call and
// cudaGetErrorString() of the result when it is not cudaSuccess.
// The do { ... } while(0) wrapper makes the macro behave like one statement.
+ #define STRINGIFY(x) #x
11
+ #define STR(x) STRINGIFY(x)
12
+ #define FILE_LINE __FILE__ ":" STR(__LINE__)
13
+ #define CUDA_CHECK_THROW(x) \
14
+ do { \
15
+ cudaError_t _result = x; \
16
+ if (_result != cudaSuccess) \
17
+ throw std::runtime_error(std::string(FILE_LINE " check failed " #x " failed: ") + cudaGetErrorString(_result)); \
18
+ } while(0)
19
+
20
+ #if defined(__HIPCC__)
21
+ #define cudaMallocAsync hipMallocAsync
22
+ #define cudaFreeAsync hipFreeAsync
23
+ #endif
24
+
25
+ namespace texture_baker_cpp
26
+ {
27
+
28
+ __device__ float3 operator+(const float3 &a, const float3 &b)
29
+ {
30
+ return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
31
+ }
32
+
33
+ // xy: 2D test position
34
+ // v1: vertex position 1
35
+ // v2: vertex position 2
36
+ // v3: vertex position 3
37
+ //
38
+ __forceinline__ __device__ bool barycentric_coordinates(const float2 &xy, const tb_float2 &v1, const tb_float2 &v2, const tb_float2 &v3, float &u, float &v, float &w)
39
+ {
40
+ // Return true if the point (xy) is inside the triangle defined by the vertices v1, v2, v3.
41
+ // If the point is inside the triangle, the barycentric coordinates are stored in u, v, and w.
42
+ float2 v1v2 = make_float2(v2.x - v1.x, v2.y - v1.y);
43
+ float2 v1v3 = make_float2(v3.x - v1.x, v3.y - v1.y);
44
+ float2 xyv1 = make_float2(xy.x - v1.x, xy.y - v1.y);
45
+
46
+ float d00 = v1v2.x * v1v2.x + v1v2.y * v1v2.y;
47
+ float d01 = v1v2.x * v1v3.x + v1v2.y * v1v3.y;
48
+ float d11 = v1v3.x * v1v3.x + v1v3.y * v1v3.y;
49
+ float d20 = xyv1.x * v1v2.x + xyv1.y * v1v2.y;
50
+ float d21 = xyv1.x * v1v3.x + xyv1.y * v1v3.y;
51
+
52
+ float denom = d00 * d11 - d01 * d01;
53
+ v = (d11 * d20 - d01 * d21) / denom;
54
+ w = (d00 * d21 - d01 * d20) / denom;
55
+ u = 1.0f - v - w;
56
+
57
+ return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f);
58
+ }
59
+
60
+ __global__ void kernel_interpolate(const float3* __restrict__ attr, const int3* __restrict__ indices, const float4* __restrict__ rast, float3* __restrict__ output, int width, int height)
61
+ {
62
+ // Interpolate the attr into output based on the rast result (barycentric coordinates, + triangle idx)
63
+ //int idx = x * width + y;
64
+ int idx = blockIdx.x * blockDim.x + threadIdx.x;
65
+ int x = idx / width;
66
+ int y = idx % width;
67
+
68
+ if (x >= width || y >= height)
69
+ return;
70
+
71
+ float4 barycentric = rast[idx];
72
+ int triangle_idx = int(barycentric.w);
73
+
74
+ if (triangle_idx < 0)
75
+ {
76
+ output[idx] = make_float3(0.0f, 0.0f, 0.0f);
77
+ return;
78
+ }
79
+
80
+ float3 v1 = attr[indices[triangle_idx].x];
81
+ float3 v2 = attr[indices[triangle_idx].y];
82
+ float3 v3 = attr[indices[triangle_idx].z];
83
+
84
+ output[idx] = make_float3(v1.x * barycentric.x, v1.y * barycentric.x, v1.z * barycentric.x)
85
+ + make_float3(v2.x * barycentric.y, v2.y * barycentric.y, v2.z * barycentric.y)
86
+ + make_float3(v3.x * barycentric.z, v3.y * barycentric.z, v3.z * barycentric.z);
87
+ }
88
+
89
+ __device__ bool bvh_intersect(
90
+ const BVHNode* __restrict__ nodes,
91
+ const Triangle* __restrict__ triangles,
92
+ const int* __restrict__ triangle_indices,
93
+ const int root,
94
+ const float2 &point,
95
+ float &u, float &v, float &w,
96
+ int &index)
97
+ {
98
+ constexpr int max_stack_size = 64;
99
+ int node_stack[max_stack_size];
100
+ int stack_size = 0;
101
+
102
+ node_stack[stack_size++] = root;
103
+
104
+ while (stack_size > 0)
105
+ {
106
+ int node_idx = node_stack[--stack_size];
107
+ const BVHNode &node = nodes[node_idx];
108
+
109
+ if (node.is_leaf())
110
+ {
111
+ for (int i = node.start; i < node.end; ++i)
112
+ {
113
+ const Triangle &tri = triangles[triangle_indices[i]];
114
+ if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w))
115
+ {
116
+ index = tri.index;
117
+ return true;
118
+ }
119
+ }
120
+ }
121
+ else
122
+ {
123
+ if (nodes[node.right].bbox.overlaps(point))
124
+ {
125
+ if (stack_size < max_stack_size)
126
+ {
127
+ node_stack[stack_size++] = node.right;
128
+ }
129
+ else
130
+ {
131
+ // Handle stack overflow
132
+ // Make sure NDEBUG is not defined (see setup.py)
133
+ assert(0 && "Node stack overflow");
134
+ }
135
+ }
136
+ if (nodes[node.left].bbox.overlaps(point))
137
+ {
138
+ if (stack_size < max_stack_size)
139
+ {
140
+ node_stack[stack_size++] = node.left;
141
+ }
142
+ else
143
+ {
144
+ // Handle stack overflow
145
+ // Make sure NDEBUG is not defined (see setup.py)
146
+ assert(0 && "Node stack overflow");
147
+ }
148
+ }
149
+ }
150
+ }
151
+
152
+ return false;
153
+ }
154
+
155
+ __global__ void kernel_bake_uv(
156
+ float2* __restrict__ uv,
157
+ int3* __restrict__ indices,
158
+ float4* __restrict__ output,
159
+ const BVHNode* __restrict__ nodes,
160
+ const Triangle* __restrict__ triangles,
161
+ const int* __restrict__ triangle_indices,
162
+ const int root,
163
+ const int width,
164
+ const int height,
165
+ const int num_indices)
166
+ {
167
+ //int idx = x * width + y;
168
+ int idx = blockIdx.x * blockDim.x + threadIdx.x;
169
+ int x = idx / width;
170
+ int y = idx % width;
171
+
172
+ if (y >= width || x >= height)
173
+ return;
174
+
175
+ // We index x,y but the original coords are HW. So swap them
176
+ float2 pixel_coord = make_float2(float(y) / height, float(x) / width);
177
+ pixel_coord.x = fminf(fmaxf(pixel_coord.x, 0.0f), 1.0f);
178
+ pixel_coord.y = 1.0f - fminf(fmaxf(pixel_coord.y, 0.0f), 1.0f);
179
+
180
+ float u, v, w;
181
+ int triangle_idx;
182
+ bool hit = bvh_intersect(nodes, triangles, triangle_indices, root, pixel_coord, u, v, w, triangle_idx);
183
+
184
+ if (hit)
185
+ {
186
+ output[idx] = make_float4(u, v, w, float(triangle_idx));
187
+ return;
188
+ }
189
+
190
+ output[idx] = make_float4(0.0f, 0.0f, 0.0f, -1.0f);
191
+ }
192
+
193
+ torch::Tensor rasterize_gpu(
194
+ torch::Tensor uv,
195
+ torch::Tensor indices,
196
+ int64_t bake_resolution)
197
+ {
198
+ #ifdef TIMING
199
+ auto start = std::chrono::high_resolution_clock::now();
200
+ #endif
201
+ constexpr int block_size = 16 * 16;
202
+ int grid_size = bake_resolution * bake_resolution / block_size;
203
+ dim3 block_dims(block_size, 1, 1);
204
+ dim3 grid_dims(grid_size, 1, 1);
205
+
206
+ int num_indices = indices.size(0);
207
+
208
+ int width = bake_resolution;
209
+ int height = bake_resolution;
210
+
211
+ // Step 1: create an empty tensor to store the output.
212
+ torch::Tensor rast_result = torch::empty({bake_resolution, bake_resolution, 4}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));
213
+
214
+ auto vertices_cpu = uv.contiguous().cpu();
215
+ auto indices_cpu = indices.contiguous().cpu();
216
+
217
+ const tb_float2 *vertices_cpu_ptr = (tb_float2*)vertices_cpu.contiguous().data_ptr<float>();
218
+ const tb_int3 *tris_cpu_ptr = (tb_int3*)indices_cpu.contiguous().data_ptr<int>();
219
+
220
+ BVH bvh;
221
+ bvh.build(vertices_cpu_ptr, tris_cpu_ptr, indices.size(0));
222
+
223
+ BVHNode *nodes_gpu = nullptr;
224
+ Triangle *triangles_gpu = nullptr;
225
+ int *triangle_indices_gpu = nullptr;
226
+ const int bvh_root = bvh.root;
227
+ cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
228
+
229
+ CUDA_CHECK_THROW(cudaMallocAsync(&nodes_gpu, sizeof(BVHNode) * bvh.nodes.size(), cuda_stream));
230
+ CUDA_CHECK_THROW(cudaMallocAsync(&triangles_gpu, sizeof(Triangle) * bvh.triangles.size(), cuda_stream));
231
+ CUDA_CHECK_THROW(cudaMallocAsync(&triangle_indices_gpu, sizeof(int) * bvh.triangle_indices.size(), cuda_stream));
232
+
233
+ CUDA_CHECK_THROW(cudaMemcpyAsync(nodes_gpu, bvh.nodes.data(), sizeof(BVHNode) * bvh.nodes.size(), cudaMemcpyHostToDevice, cuda_stream));
234
+ CUDA_CHECK_THROW(cudaMemcpyAsync(triangles_gpu, bvh.triangles.data(), sizeof(Triangle) * bvh.triangles.size(), cudaMemcpyHostToDevice, cuda_stream));
235
+ CUDA_CHECK_THROW(cudaMemcpyAsync(triangle_indices_gpu, bvh.triangle_indices.data(), sizeof(int) * bvh.triangle_indices.size(), cudaMemcpyHostToDevice, cuda_stream));
236
+
237
+ kernel_bake_uv<<<grid_dims, block_dims, 0, cuda_stream>>>(
238
+ (float2 *)uv.contiguous().data_ptr<float>(),
239
+ (int3 *)indices.contiguous().data_ptr<int>(),
240
+ (float4 *)rast_result.contiguous().data_ptr<float>(),
241
+ nodes_gpu,
242
+ triangles_gpu,
243
+ triangle_indices_gpu,
244
+ bvh_root,
245
+ width,
246
+ height,
247
+ num_indices);
248
+
249
+ CUDA_CHECK_THROW(cudaFreeAsync(nodes_gpu, cuda_stream));
250
+ CUDA_CHECK_THROW(cudaFreeAsync(triangles_gpu, cuda_stream));
251
+ CUDA_CHECK_THROW(cudaFreeAsync(triangle_indices_gpu, cuda_stream));
252
+
253
+ #ifdef TIMING
254
+ CUDA_CHECK_THROW(cudaStreamSynchronize(cuda_stream));
255
+ auto end = std::chrono::high_resolution_clock::now();
256
+ std::chrono::duration<double> elapsed = end - start;
257
+ std::cout << "Rasterization time (CUDA): " << elapsed.count() << "s" << std::endl;
258
+ #endif
259
+ return rast_result;
260
+ }
261
+
262
+ torch::Tensor interpolate_gpu(
263
+ torch::Tensor attr,
264
+ torch::Tensor indices,
265
+ torch::Tensor rast)
266
+ {
267
+ #ifdef TIMING
268
+ auto start = std::chrono::high_resolution_clock::now();
269
+ #endif
270
+ constexpr int block_size = 16 * 16;
271
+ int grid_size = rast.size(0) * rast.size(0) / block_size;
272
+ dim3 block_dims(block_size, 1, 1);
273
+ dim3 grid_dims(grid_size, 1, 1);
274
+
275
+ // Step 1: create an empty tensor to store the output.
276
+ torch::Tensor pos_bake = torch::empty({rast.size(0), rast.size(1), 3}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));
277
+
278
+ int width = rast.size(0);
279
+ int height = rast.size(1);
280
+
281
+ cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
282
+
283
+ kernel_interpolate<<<grid_dims, block_dims, 0, cuda_stream>>>(
284
+ (float3 *)attr.contiguous().data_ptr<float>(),
285
+ (int3 *)indices.contiguous().data_ptr<int>(),
286
+ (float4 *)rast.contiguous().data_ptr<float>(),
287
+ (float3 *)pos_bake.contiguous().data_ptr<float>(),
288
+ width,
289
+ height);
290
+ #ifdef TIMING
291
+ CUDA_CHECK_THROW(cudaStreamSynchronize(cuda_stream));
292
+ auto end = std::chrono::high_resolution_clock::now();
293
+ std::chrono::duration<double> elapsed = end - start;
294
+ std::cout << "Interpolation time (CUDA): " << elapsed.count() << "s" << std::endl;
295
+ #endif
296
+ return pos_bake;
297
+ }
298
+
299
+ // Registers CUDA implementations
300
// Binds rasterize_gpu and interpolate_gpu as the CUDA dispatch-key
// implementations of the "rasterize" and "interpolate" operators of the
// texture_baker_cpp library (the operator schemas are declared elsewhere).
+ TORCH_LIBRARY_IMPL(texture_baker_cpp, CUDA, m)
301
+ {
302
+ m.impl("rasterize", &rasterize_gpu);
303
+ m.impl("interpolate", &interpolate_gpu);
304
+ }
305
+
306
+ }
texture_baker/texture_baker/csrc/baker_kernel.metal ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <metal_stdlib>
2
+ using namespace metal;
3
+
4
+ // This header is inlined manually
5
+ //#include "baker.h"
6
+
7
+ // Use the texture_baker_cpp so it can use the classes from baker.h
8
+ using namespace texture_baker_cpp;
9
+
10
+ // Utility function to compute barycentric coordinates
11
+ bool barycentric_coordinates(float2 xy, float2 v1, float2 v2, float2 v3, thread float &u, thread float &v, thread float &w) {
12
+ float2 v1v2 = v2 - v1;
13
+ float2 v1v3 = v3 - v1;
14
+ float2 xyv1 = xy - v1;
15
+
16
+ float d00 = dot(v1v2, v1v2);
17
+ float d01 = dot(v1v2, v1v3);
18
+ float d11 = dot(v1v3, v1v3);
19
+ float d20 = dot(xyv1, v1v2);
20
+ float d21 = dot(xyv1, v1v3);
21
+
22
+ float denom = d00 * d11 - d01 * d01;
23
+ v = (d11 * d20 - d01 * d21) / denom;
24
+ w = (d00 * d21 - d01 * d20) / denom;
25
+ u = 1.0f - v - w;
26
+
27
+ return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f);
28
+ }
29
+
30
+ // Kernel function for interpolation
31
+ kernel void kernel_interpolate(constant packed_float3 *attr [[buffer(0)]],
32
+ constant packed_int3 *indices [[buffer(1)]],
33
+ constant packed_float4 *rast [[buffer(2)]],
34
+ device packed_float3 *output [[buffer(3)]],
35
+ constant int &width [[buffer(4)]],
36
+ constant int &height [[buffer(5)]],
37
+ uint3 blockIdx [[threadgroup_position_in_grid]],
38
+ uint3 threadIdx [[thread_position_in_threadgroup]],
39
+ uint3 blockDim [[threads_per_threadgroup]])
40
+ {
41
+ // Calculate global position using threadgroup and thread positions
42
+ int x = blockIdx.x * blockDim.x + threadIdx.x;
43
+ int y = blockIdx.y * blockDim.y + threadIdx.y;
44
+
45
+ if (x >= width || y >= height) return;
46
+
47
+ int idx = y * width + x;
48
+ float4 barycentric = rast[idx];
49
+ int triangle_idx = int(barycentric.w);
50
+
51
+ if (triangle_idx < 0) {
52
+ output[idx] = float3(0.0f, 0.0f, 0.0f);
53
+ return;
54
+ }
55
+
56
+ float3 v1 = attr[indices[triangle_idx].x];
57
+ float3 v2 = attr[indices[triangle_idx].y];
58
+ float3 v3 = attr[indices[triangle_idx].z];
59
+
60
+ output[idx] = v1 * barycentric.x + v2 * barycentric.y + v3 * barycentric.z;
61
+ }
62
+
63
+ bool bvh_intersect(
64
+ constant BVHNode* nodes,
65
+ constant Triangle* triangles,
66
+ constant int* triangle_indices,
67
+ const thread int root,
68
+ const thread float2 &point,
69
+ thread float &u, thread float &v, thread float &w,
70
+ thread int &index)
71
+ {
72
+ const int max_stack_size = 64;
73
+ thread int node_stack[max_stack_size];
74
+ int stack_size = 0;
75
+
76
+ node_stack[stack_size++] = root;
77
+
78
+ while (stack_size > 0)
79
+ {
80
+ int node_idx = node_stack[--stack_size];
81
+ BVHNode node = nodes[node_idx];
82
+
83
+ if (node.is_leaf())
84
+ {
85
+ for (int i = node.start; i < node.end; ++i)
86
+ {
87
+ constant Triangle &tri = triangles[triangle_indices[i]];
88
+ if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w))
89
+ {
90
+ index = tri.index;
91
+ return true;
92
+ }
93
+ }
94
+ }
95
+ else
96
+ {
97
+ BVHNode test_node = nodes[node.right];
98
+ if (test_node.bbox.overlaps(point))
99
+ {
100
+ if (stack_size < max_stack_size)
101
+ {
102
+ node_stack[stack_size++] = node.right;
103
+ }
104
+ else
105
+ {
106
+ // Handle stack overflow
107
+ // Sadly, metal doesn't support asserts (but you could try enabling metal validation layers)
108
+ return false;
109
+ }
110
+ }
111
+ test_node = nodes[node.left];
112
+ if (test_node.bbox.overlaps(point))
113
+ {
114
+ if (stack_size < max_stack_size)
115
+ {
116
+ node_stack[stack_size++] = node.left;
117
+ }
118
+ else
119
+ {
120
+ // Handle stack overflow
121
+ return false;
122
+ }
123
+ }
124
+ }
125
+ }
126
+
127
+ return false;
128
+ }
129
+
130
+
131
+ // Kernel function for baking UV
132
+ kernel void kernel_bake_uv(constant packed_float2 *uv [[buffer(0)]],
133
+ constant packed_int3 *indices [[buffer(1)]],
134
+ device packed_float4 *output [[buffer(2)]],
135
+ constant BVHNode *nodes [[buffer(3)]],
136
+ constant Triangle *triangles [[buffer(4)]],
137
+ constant int *triangle_indices [[buffer(5)]],
138
+ constant int &root [[buffer(6)]],
139
+ constant int &width [[buffer(7)]],
140
+ constant int &height [[buffer(8)]],
141
+ constant int &num_indices [[buffer(9)]],
142
+ uint3 blockIdx [[threadgroup_position_in_grid]],
143
+ uint3 threadIdx [[thread_position_in_threadgroup]],
144
+ uint3 blockDim [[threads_per_threadgroup]])
145
+ {
146
+ // Calculate global position using threadgroup and thread positions
147
+ int x = blockIdx.x * blockDim.x + threadIdx.x;
148
+ int y = blockIdx.y * blockDim.y + threadIdx.y;
149
+
150
+
151
+ if (x >= width || y >= height) return;
152
+
153
+ int idx = x * width + y;
154
+
155
+ // Swap original coordinates
156
+ float2 pixel_coord = float2(float(y) / float(height), float(x) / float(width));
157
+ pixel_coord = clamp(pixel_coord, 0.0f, 1.0f);
158
+ pixel_coord.y = 1.0f - pixel_coord.y;
159
+
160
+ float u, v, w;
161
+ int triangle_idx;
162
+ bool hit = bvh_intersect(nodes, triangles, triangle_indices, root, pixel_coord, u, v, w, triangle_idx);
163
+
164
+ if (hit) {
165
+ output[idx] = float4(u, v, w, float(triangle_idx));
166
+ return;
167
+ }
168
+
169
+ output[idx] = float4(0.0f, 0.0f, 0.0f, -1.0f);
170
+ }
texture_baker/texture_baker/csrc/baker_kernel.mm ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include <ATen/ATen.h>
3
+ #include <ATen/Context.h>
4
+ #include "baker.h"
5
+
6
+ #import <Foundation/Foundation.h>
7
+ #import <Metal/Metal.h>
8
+ #include <filesystem>
9
+
10
+ // Helper function to retrieve the `MTLBuffer` from a `torch::Tensor`.
11
+ static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
12
+ return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
13
+ }
14
+
15
+ // Helper function to create a compute pipeline state object (PSO).
16
+ static inline id<MTLComputePipelineState> createComputePipelineState(id<MTLDevice> device, NSString* fullSource, std::string kernel_name) {
17
+ NSError *error = nil;
18
+
19
+ // Load the custom kernel shader.
20
+ MTLCompileOptions *options = [[MTLCompileOptions alloc] init];
21
+ // Add the preprocessor macro "__METAL__"
22
+ options.preprocessorMacros = @{@"__METAL__": @""};
23
+ id<MTLLibrary> customKernelLibrary = [device newLibraryWithSource: fullSource options:options error:&error];
24
+ TORCH_CHECK(customKernelLibrary, "Failed to create custom kernel library, error: ", error.localizedDescription.UTF8String);
25
+
26
+ id<MTLFunction> customKernelFunction = [customKernelLibrary newFunctionWithName:[NSString stringWithUTF8String:kernel_name.c_str()]];
27
+ TORCH_CHECK(customKernelFunction, "Failed to create function state object for ", kernel_name.c_str());
28
+
29
+ id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:customKernelFunction error:&error];
30
+ TORCH_CHECK(pso, error.localizedDescription.UTF8String);
31
+
32
+ return pso;
33
+ }
34
+
35
+ std::filesystem::path get_extension_path() {
36
+ // Ensure the GIL is held before calling any Python C API function
37
+ PyGILState_STATE gstate = PyGILState_Ensure();
38
+
39
+ const char* module_name = "texture_baker";
40
+
41
+ // Import the module by name
42
+ PyObject* module = PyImport_ImportModule(module_name);
43
+ if (!module) {
44
+ PyGILState_Release(gstate);
45
+ throw std::runtime_error("Could not import the module: " + std::string(module_name));
46
+ }
47
+
48
+ // Get the filename of the module
49
+ PyObject* filename_obj = PyModule_GetFilenameObject(module);
50
+ if (filename_obj) {
51
+ std::string path = PyUnicode_AsUTF8(filename_obj);
52
+ Py_DECREF(filename_obj);
53
+ PyGILState_Release(gstate);
54
+
55
+ // Get the directory part of the path (removing the __init__.py)
56
+ std::filesystem::path module_path = std::filesystem::path(path).parent_path();
57
+
58
+ // Append the 'csrc' directory to the path
59
+ module_path /= "csrc";
60
+
61
+ return module_path;
62
+ } else {
63
+ PyGILState_Release(gstate);
64
+ throw std::runtime_error("Could not retrieve the module filename.");
65
+ }
66
+ }
67
+
68
+ NSString *get_shader_sources_as_string()
69
+ {
70
+ const std::filesystem::path csrc_path = get_extension_path();
71
+ const std::string shader_path = (csrc_path / "baker_kernel.metal").string();
72
+ const std::string shader_header_path = (csrc_path / "baker.h").string();
73
+ // Load the Metal shader from the specified path
74
+ NSError *error = nil;
75
+
76
+ NSString* shaderHeaderSource = [
77
+ NSString stringWithContentsOfFile:[NSString stringWithUTF8String:shader_header_path.c_str()]
78
+ encoding:NSUTF8StringEncoding
79
+ error:&error];
80
+ if (error) {
81
+ throw std::runtime_error("Failed to load baker.h: " + std::string(error.localizedDescription.UTF8String));
82
+ }
83
+
84
+ NSString* shaderSource = [
85
+ NSString stringWithContentsOfFile:[NSString stringWithUTF8String:shader_path.c_str()]
86
+ encoding:NSUTF8StringEncoding
87
+ error:&error];
88
+ if (error) {
89
+ throw std::runtime_error("Failed to load Metal shader: " + std::string(error.localizedDescription.UTF8String));
90
+ }
91
+
92
+ NSString *fullSource = [shaderHeaderSource stringByAppendingString:shaderSource];
93
+
94
+ return fullSource;
95
+ }
96
+
97
+ namespace texture_baker_cpp
98
+ {
99
+ torch::Tensor rasterize_gpu(
100
+ torch::Tensor uv,
101
+ torch::Tensor indices,
102
+ int64_t bake_resolution)
103
+ {
104
+ TORCH_CHECK(uv.device().is_mps(), "uv must be a MPS tensor");
105
+ TORCH_CHECK(uv.is_contiguous(), "uv must be contiguous");
106
+ TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous");
107
+
108
+ TORCH_CHECK(uv.scalar_type() == torch::kFloat32, "Unsupported data type: ", indices.scalar_type());
109
+ TORCH_CHECK(indices.scalar_type() == torch::kInt32, "Unsupported data type: ", indices.scalar_type());
110
+
111
+ torch::Tensor rast_result = torch::empty({bake_resolution, bake_resolution, 4}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kMPS)).contiguous();
112
+
113
+ @autoreleasepool {
114
+ auto vertices_cpu = uv.contiguous().cpu();
115
+ auto indices_cpu = indices.contiguous().cpu();
116
+
117
+ const tb_float2 *vertices_cpu_ptr = (tb_float2*)vertices_cpu.contiguous().data_ptr<float>();
118
+ const tb_int3 *tris_cpu_ptr = (tb_int3*)indices_cpu.contiguous().data_ptr<int>();
119
+
120
+ BVH bvh;
121
+ bvh.build(vertices_cpu_ptr, tris_cpu_ptr, indices.size(0));
122
+
123
+ id<MTLDevice> device = MTLCreateSystemDefaultDevice();
124
+
125
+ NSString *fullSource = get_shader_sources_as_string();
126
+
127
+ // Create a compute pipeline state object using the helper function
128
+ id<MTLComputePipelineState> bake_uv_PSO = createComputePipelineState(device, fullSource, "kernel_bake_uv");
129
+
130
+ // Get a reference to the command buffer for the MPS stream.
131
+ id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
132
+ TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");
133
+
134
+ // Get a reference to the dispatch queue for the MPS stream, which encodes the synchronization with the CPU.
135
+ dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();
136
+
137
+ dispatch_sync(serialQueue, ^(){
138
+ // Start a compute pass.
139
+ id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
140
+ TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");
141
+
142
+ // Get Metal buffers directly from PyTorch tensors
143
+ auto uv_buf = getMTLBufferStorage(uv.contiguous());
144
+ auto indices_buf = getMTLBufferStorage(indices.contiguous());
145
+ auto rast_result_buf = getMTLBufferStorage(rast_result);
146
+
147
+ const int width = bake_resolution;
148
+ const int height = bake_resolution;
149
+ const int num_indices = indices.size(0);
150
+ const int bvh_root = bvh.root;
151
+
152
+ // Wrap the existing CPU memory in Metal buffers with shared memory
153
+ id<MTLBuffer> nodesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.nodes.data() length:sizeof(BVHNode) * bvh.nodes.size() options:MTLResourceStorageModeShared deallocator:nil];
154
+ id<MTLBuffer> trianglesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.triangles.data() length:sizeof(Triangle) * bvh.triangles.size() options:MTLResourceStorageModeShared deallocator:nil];
155
+ id<MTLBuffer> triangleIndicesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.triangle_indices.data() length:sizeof(int) * bvh.triangle_indices.size() options:MTLResourceStorageModeShared deallocator:nil];
156
+
157
+ [computeEncoder setComputePipelineState:bake_uv_PSO];
158
+ [computeEncoder setBuffer:uv_buf offset:uv.storage_offset() * uv.element_size() atIndex:0];
159
+ [computeEncoder setBuffer:indices_buf offset:indices.storage_offset() * indices.element_size() atIndex:1];
160
+ [computeEncoder setBuffer:rast_result_buf offset:rast_result.storage_offset() * rast_result.element_size() atIndex:2];
161
+ [computeEncoder setBuffer:nodesBuffer offset:0 atIndex:3];
162
+ [computeEncoder setBuffer:trianglesBuffer offset:0 atIndex:4];
163
+ [computeEncoder setBuffer:triangleIndicesBuffer offset:0 atIndex:5];
164
+ [computeEncoder setBytes:&bvh_root length:sizeof(int) atIndex:6];
165
+ [computeEncoder setBytes:&width length:sizeof(int) atIndex:7];
166
+ [computeEncoder setBytes:&height length:sizeof(int) atIndex:8];
167
+ [computeEncoder setBytes:&num_indices length:sizeof(int) atIndex:9];
168
+
169
+ // Calculate a thread group size.
170
+ int block_size = 16;
171
+ MTLSize threadgroupSize = MTLSizeMake(block_size, block_size, 1); // Fixed threadgroup size
172
+ MTLSize numThreadgroups = MTLSizeMake(bake_resolution / block_size, bake_resolution / block_size, 1);
173
+
174
+ // Encode the compute command.
175
+ [computeEncoder dispatchThreadgroups:numThreadgroups threadsPerThreadgroup:threadgroupSize];
176
+ [computeEncoder endEncoding];
177
+
178
+ // Commit the work.
179
+ torch::mps::commit();
180
+ });
181
+ }
182
+
183
+ return rast_result;
184
+ }
185
+
186
+ torch::Tensor interpolate_gpu(
187
+ torch::Tensor attr,
188
+ torch::Tensor indices,
189
+ torch::Tensor rast)
190
+ {
191
+ TORCH_CHECK(attr.is_contiguous(), "attr must be contiguous");
192
+ TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous");
193
+ TORCH_CHECK(rast.is_contiguous(), "rast must be contiguous");
194
+
195
+ torch::Tensor pos_bake = torch::empty({rast.size(0), rast.size(1), 3}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kMPS)).contiguous();
196
+ std::filesystem::path csrc_path = get_extension_path();
197
+
198
+ @autoreleasepool {
199
+ id<MTLDevice> device = MTLCreateSystemDefaultDevice();
200
+
201
+ NSString *fullSource = get_shader_sources_as_string();
202
+ // Create a compute pipeline state object using the helper function
203
+ id<MTLComputePipelineState> interpolate_PSO = createComputePipelineState(device, fullSource, "kernel_interpolate");
204
+
205
+ // Get a reference to the command buffer for the MPS stream.
206
+ id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
207
+ TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");
208
+
209
+ // Get a reference to the dispatch queue for the MPS stream, which encodes the synchronization with the CPU.
210
+ dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();
211
+
212
+ dispatch_sync(serialQueue, ^(){
213
+ // Start a compute pass.
214
+ id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
215
+ TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");
216
+
217
+ // Get Metal buffers directly from PyTorch tensors
218
+ auto attr_buf = getMTLBufferStorage(attr.contiguous());
219
+ auto indices_buf = getMTLBufferStorage(indices.contiguous());
220
+ auto rast_buf = getMTLBufferStorage(rast.contiguous());
221
+ auto pos_bake_buf = getMTLBufferStorage(pos_bake);
222
+
223
+ int width = rast.size(0);
224
+ int height = rast.size(1);
225
+
226
+ [computeEncoder setComputePipelineState:interpolate_PSO];
227
+ [computeEncoder setBuffer:attr_buf offset:attr.storage_offset() * attr.element_size() atIndex:0];
228
+ [computeEncoder setBuffer:indices_buf offset:indices.storage_offset() * indices.element_size() atIndex:1];
229
+ [computeEncoder setBuffer:rast_buf offset:rast.storage_offset() * rast.element_size() atIndex:2];
230
+ [computeEncoder setBuffer:pos_bake_buf offset:pos_bake.storage_offset() * pos_bake.element_size() atIndex:3];
231
+ [computeEncoder setBytes:&width length:sizeof(int) atIndex:4];
232
+ [computeEncoder setBytes:&height length:sizeof(int) atIndex:5];
233
+
234
+ // Calculate a thread group size.
235
+
236
+ int block_size = 16;
237
+ MTLSize threadgroupSize = MTLSizeMake(block_size, block_size, 1); // Fixed threadgroup size
238
+ MTLSize numThreadgroups = MTLSizeMake(rast.size(0) / block_size, rast.size(0) / block_size, 1);
239
+
240
+ // Encode the compute command.
241
+ [computeEncoder dispatchThreadgroups:numThreadgroups threadsPerThreadgroup:threadgroupSize];
242
+
243
+ [computeEncoder endEncoding];
244
+
245
+ // Commit the work.
246
+ torch::mps::commit();
247
+ });
248
+ }
249
+
250
+ return pos_bake;
251
+ }
252
+
253
+ // Registers MPS implementations
254
// Binds rasterize_gpu and interpolate_gpu as the MPS dispatch-key
// implementations of the "rasterize" and "interpolate" operators of the
// texture_baker_cpp library (the operator schemas are declared elsewhere).
+ TORCH_LIBRARY_IMPL(texture_baker_cpp, MPS, m)
255
+ {
256
+ m.impl("rasterize", &rasterize_gpu);
257
+ m.impl("interpolate", &interpolate_gpu);
258
+ }
259
+
260
+ }
uv_unwrapper/README.md ADDED
File without changes
uv_unwrapper/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ numpy
uv_unwrapper/setup.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+
4
+ import torch
5
+ from setuptools import find_packages, setup
6
+ from torch.utils.cpp_extension import (
7
+ BuildExtension,
8
+ CppExtension,
9
+ )
10
+
11
library_name = "uv_unwrapper"


def get_extensions():
    """Build the list of C++ extension modules for this package.

    Environment variables:
        DEBUG: set to "1" to compile without optimization and with debug
            symbols (and with assertions enabled).
        USE_NATIVE_ARCH: set to "0" to omit ``-march=native``. Ignored on
            macOS, where the flag is never added.

    Returns:
        A list with the single ``<library_name>._C`` extension, or ``None``
        when no ``.cpp`` sources are found under ``<library_name>/csrc``.
    """
    debug_mode = os.getenv("DEBUG", "0") == "1"
    if debug_mode:
        print("Compiling in debug mode")

    is_mac = torch.backends.mps.is_available()
    use_native_arch = not is_mac and os.getenv("USE_NATIVE_ARCH", "1") == "1"
    extension = CppExtension

    extra_link_args = []

    # Build the flag list step by step. The previous single conditional
    # expression parsed as ``A if use_native_arch else (B if is_mac else [])``,
    # so the optimization and OpenMP flags were silently dropped whenever
    # ``-march=native`` was disabled (always the case on macOS).
    cxx_flags = [
        "-O3" if not debug_mode else "-O0",
        "-fdiagnostics-color=always",
    ]
    if is_mac:
        # Apple clang needs -fopenmp forwarded via -Xclang; they must be two
        # separate arguments, not one string containing a space.
        cxx_flags += ["-Xclang", "-fopenmp"]
    else:
        cxx_flags.append("-fopenmp")
    if use_native_arch:
        cxx_flags.append("-march=native")
    if is_mac:
        cxx_flags.append("-mmacosx-version-min=10.15")

    extra_compile_args = {"cxx": cxx_flags}
    if debug_mode:
        extra_compile_args["cxx"].append("-g")
        extra_compile_args["cxx"].append("-UNDEBUG")
        extra_link_args.extend(["-O0", "-g"])

    define_macros = []
    extensions = []

    this_dir = os.path.dirname(os.path.curdir)
    sources = glob.glob(
        os.path.join(this_dir, library_name, "csrc", "**", "*.cpp"), recursive=True
    )

    if not sources:
        print("No source files found for extension, skipping extension compilation")
        return None

    # Extra libraries are only requested explicitly on macOS (including the
    # OpenMP runtime); elsewhere the extension helper's defaults are used.
    libraries = (
        ["c10", "torch", "torch_cpu", "torch_python", "omp"] if is_mac else []
    )

    extensions.append(
        extension(
            name=f"{library_name}._C",
            sources=sources,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args,
            libraries=libraries,
        )
    )

    print(extensions)

    return extensions
71
+
72
+
73
# Read the long description up front so the file handle is closed
# deterministically and decoding does not depend on the platform's default
# encoding.
with open("README.md", encoding="utf-8") as readme_file:
    long_description = readme_file.read()

setup(
    name=library_name,
    version="0.0.1",
    packages=find_packages(),
    ext_modules=get_extensions(),
    install_requires=[],
    description="Box projection based UV unwrapper",
    long_description=long_description,
    long_description_content_type="text/markdown",
    cmdclass={"build_ext": BuildExtension},
)
uv_unwrapper/uv_unwrapper/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import torch # noqa: F401
2
+
3
+ from . import _C # noqa: F401
4
+ from .unwrap import Unwrapper
5
+
6
+ __all__ = ["Unwrapper"]
uv_unwrapper/uv_unwrapper/csrc/bvh.cpp ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ #include "bvh.h"
4
+ #include "common.h"
5
+ #include <cstring>
6
+ #include <iostream>
7
+ #include <queue>
8
+ #include <tuple>
9
+ #include <utility>
10
+
11
+ namespace UVUnwrapper {
12
+ BVH::BVH(Triangle *tri, int *actual_idx, const size_t &num_indices) {
13
+ // Copty tri to triangle
14
+ triangle = new Triangle[num_indices];
15
+ memcpy(triangle, tri, num_indices * sizeof(Triangle));
16
+
17
+ // Copy actual_idx to actualIdx
18
+ actualIdx = new int[num_indices];
19
+ memcpy(actualIdx, actual_idx, num_indices * sizeof(int));
20
+
21
+ triIdx = new int[num_indices];
22
+ triCount = num_indices;
23
+
24
+ bvhNode = new BVHNode[triCount * 2 + 64];
25
+ nodesUsed = 2;
26
+ memset(bvhNode, 0, triCount * 2 * sizeof(BVHNode));
27
+
28
+ // populate triangle index array
29
+ for (int i = 0; i < triCount; i++)
30
+ triIdx[i] = i;
31
+
32
+ BVHNode &root = bvhNode[0];
33
+
34
+ root.start = 0, root.end = triCount;
35
+ AABB centroidBounds;
36
+ UpdateNodeBounds(0, centroidBounds);
37
+
38
+ // subdivide recursively
39
+ Subdivide(0, nodesUsed, centroidBounds);
40
+ }
41
+
42
+ BVH::BVH(const BVH &other)
43
+ : BVH(other.triangle, other.triIdx, other.triCount) {}
44
+
45
+ BVH::BVH(BVH &&other) noexcept // move constructor
46
+ : triIdx(std::exchange(other.triIdx, nullptr)),
47
+ actualIdx(std::exchange(other.actualIdx, nullptr)),
48
+ triangle(std::exchange(other.triangle, nullptr)),
49
+ bvhNode(std::exchange(other.bvhNode, nullptr)) {}
50
+
51
+ BVH &BVH::operator=(const BVH &other) // copy assignment
52
+ {
53
+ return *this = BVH(other);
54
+ }
55
+
56
+ BVH &BVH::operator=(BVH &&other) noexcept // move assignment
57
+ {
58
+ std::swap(triIdx, other.triIdx);
59
+ std::swap(actualIdx, other.actualIdx);
60
+ std::swap(triangle, other.triangle);
61
+ std::swap(bvhNode, other.bvhNode);
62
+ std::swap(triCount, other.triCount);
63
+ std::swap(nodesUsed, other.nodesUsed);
64
+ return *this;
65
+ }
66
+
67
+ BVH::~BVH() {
68
+ if (triIdx)
69
+ delete[] triIdx;
70
+ if (triangle)
71
+ delete[] triangle;
72
+ if (actualIdx)
73
+ delete[] actualIdx;
74
+ if (bvhNode)
75
+ delete[] bvhNode;
76
+ }
77
+
78
+ void BVH::UpdateNodeBounds(unsigned int nodeIdx, AABB &centroidBounds) {
79
+ BVHNode &node = bvhNode[nodeIdx];
80
+ #ifndef __ARM_ARCH_ISA_A64
81
+ #ifndef _MSC_VER
82
+ if (__builtin_cpu_supports("sse"))
83
+ #elif (defined(_M_AMD64) || defined(_M_X64))
84
+ // SSE supported on Windows
85
+ if constexpr (true)
86
+ #endif
87
+ {
88
+ __m128 min4 = _mm_set_ps1(FLT_MAX), max4 = _mm_set_ps1(FLT_MIN);
89
+ __m128 cmin4 = _mm_set_ps1(FLT_MAX), cmax4 = _mm_set_ps1(FLT_MIN);
90
+ for (int i = node.start; i < node.end; i += 2) {
91
+ Triangle &leafTri1 = triangle[triIdx[i]];
92
+ __m128 v0, v1, v2, centroid;
93
+ if (i + 1 < node.end) {
94
+ const Triangle leafTri2 = triangle[triIdx[i + 1]];
95
+
96
+ v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri2.v0.x,
97
+ leafTri2.v0.y);
98
+ v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri2.v1.x,
99
+ leafTri2.v1.y);
100
+ v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri2.v2.x,
101
+ leafTri2.v2.y);
102
+ centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
103
+ leafTri2.centroid.x, leafTri2.centroid.y);
104
+ } else {
105
+ // Otherwise do some duplicated work
106
+ v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri1.v0.x,
107
+ leafTri1.v0.y);
108
+ v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri1.v1.x,
109
+ leafTri1.v1.y);
110
+ v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri1.v2.x,
111
+ leafTri1.v2.y);
112
+ centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
113
+ leafTri1.centroid.x, leafTri1.centroid.y);
114
+ }
115
+
116
+ min4 = _mm_min_ps(min4, v0);
117
+ max4 = _mm_max_ps(max4, v0);
118
+ min4 = _mm_min_ps(min4, v1);
119
+ max4 = _mm_max_ps(max4, v1);
120
+ min4 = _mm_min_ps(min4, v2);
121
+ max4 = _mm_max_ps(max4, v2);
122
+ cmin4 = _mm_min_ps(cmin4, centroid);
123
+ cmax4 = _mm_max_ps(cmax4, centroid);
124
+ }
125
+ float min_values[4], max_values[4], cmin_values[4], cmax_values[4];
126
+ _mm_store_ps(min_values, min4);
127
+ _mm_store_ps(max_values, max4);
128
+ _mm_store_ps(cmin_values, cmin4);
129
+ _mm_store_ps(cmax_values, cmax4);
130
+
131
+ node.bbox.min.x = std::min(min_values[3], min_values[1]);
132
+ node.bbox.min.y = std::min(min_values[2], min_values[0]);
133
+ node.bbox.max.x = std::max(max_values[3], max_values[1]);
134
+ node.bbox.max.y = std::max(max_values[2], max_values[0]);
135
+
136
+ centroidBounds.min.x = std::min(cmin_values[3], cmin_values[1]);
137
+ centroidBounds.min.y = std::min(cmin_values[2], cmin_values[0]);
138
+ centroidBounds.max.x = std::max(cmax_values[3], cmax_values[1]);
139
+ centroidBounds.max.y = std::max(cmax_values[2], cmax_values[0]);
140
+ }
141
+ #else
142
+ if constexpr (false) {
143
+ }
144
+ #endif
145
+ else {
146
+ node.bbox.invalidate();
147
+ centroidBounds.invalidate();
148
+
149
+ // Calculate the bounding box for the node
150
+ for (int i = node.start; i < node.end; ++i) {
151
+ const Triangle &tri = triangle[triIdx[i]];
152
+ node.bbox.grow(tri.v0);
153
+ node.bbox.grow(tri.v1);
154
+ node.bbox.grow(tri.v2);
155
+ centroidBounds.grow(tri.centroid);
156
+ }
157
+ }
158
+ }
159
+
160
+ void BVH::Subdivide(unsigned int root_idx, unsigned int &nodePtr,
161
+ AABB &rootCentroidBounds) {
162
+ // Create a queue for the nodes to be subdivided
163
+ std::queue<std::tuple<unsigned int, AABB>> nodeQueue;
164
+ nodeQueue.push(std::make_tuple(root_idx, rootCentroidBounds));
165
+
166
+ while (!nodeQueue.empty()) {
167
+ // Get the next node to process from the queue
168
+ auto [node_idx, centroidBounds] = nodeQueue.front();
169
+ nodeQueue.pop();
170
+ BVHNode &node = bvhNode[node_idx];
171
+
172
+ // Check if left is -1 and right not or vice versa
173
+
174
+ int axis, splitPos;
175
+ float cost = FindBestSplitPlane(node, axis, splitPos, centroidBounds);
176
+
177
+ if (cost >= node.calculate_node_cost()) {
178
+ node.left = node.right = -1;
179
+ continue; // Move on to the next node in the queue
180
+ }
181
+
182
+ int i = node.start;
183
+ int j = node.end - 1;
184
+ float scale = BINS / (centroidBounds.max[axis] - centroidBounds.min[axis]);
185
+ while (i <= j) {
186
+ int binIdx =
187
+ std::min(BINS - 1, (int)((triangle[triIdx[i]].centroid[axis] -
188
+ centroidBounds.min[axis]) *
189
+ scale));
190
+ if (binIdx < splitPos)
191
+ i++;
192
+ else
193
+ std::swap(triIdx[i], triIdx[j--]);
194
+ }
195
+
196
+ int leftCount = i - node.start;
197
+ if (leftCount == 0 || leftCount == (int)node.num_triangles()) {
198
+ node.left = node.right = -1;
199
+ continue; // Move on to the next node in the queue
200
+ }
201
+
202
+ int mid = i;
203
+
204
+ // Create child nodes
205
+ int leftChildIdx = nodePtr++;
206
+ int rightChildIdx = nodePtr++;
207
+ bvhNode[leftChildIdx].start = node.start;
208
+ bvhNode[leftChildIdx].end = mid;
209
+ bvhNode[rightChildIdx].start = mid;
210
+ bvhNode[rightChildIdx].end = node.end;
211
+ node.left = leftChildIdx;
212
+ node.right = rightChildIdx;
213
+
214
+ // Update the bounds for the child nodes and push them onto the queue
215
+ UpdateNodeBounds(leftChildIdx, centroidBounds);
216
+ nodeQueue.push(std::make_tuple(leftChildIdx, centroidBounds));
217
+
218
+ UpdateNodeBounds(rightChildIdx, centroidBounds);
219
+ nodeQueue.push(std::make_tuple(rightChildIdx, centroidBounds));
220
+ }
221
+ }
222
+
223
+ float BVH::FindBestSplitPlane(BVHNode &node, int &best_axis, int &best_pos,
224
+ AABB &centroidBounds) {
225
+ float best_cost = FLT_MAX;
226
+
227
+ for (int axis = 0; axis < 2; ++axis) // We use 2 as we have only x and y
228
+ {
229
+ float boundsMin = centroidBounds.min[axis];
230
+ float boundsMax = centroidBounds.max[axis];
231
+ // Or floating point precision
232
+ if ((boundsMin == boundsMax) || (boundsMax - boundsMin < 1e-8f)) {
233
+ continue;
234
+ }
235
+
236
+ // populate the bins
237
+ float scale = BINS / (boundsMax - boundsMin);
238
+ float leftCountArea[BINS - 1], rightCountArea[BINS - 1];
239
+ int leftSum = 0, rightSum = 0;
240
+ #ifndef __ARM_ARCH_ISA_A64
241
+ #ifndef _MSC_VER
242
+ if (__builtin_cpu_supports("sse"))
243
+ #elif (defined(_M_AMD64) || defined(_M_X64))
244
+ // SSE supported on Windows
245
+ if constexpr (true)
246
+ #endif
247
+ {
248
+ __m128 min4[BINS], max4[BINS];
249
+ unsigned int count[BINS];
250
+ for (unsigned int i = 0; i < BINS; i++)
251
+ min4[i] = _mm_set_ps1(FLT_MAX), max4[i] = _mm_set_ps1(FLT_MIN),
252
+ count[i] = 0;
253
+ for (int i = node.start; i < node.end; i++) {
254
+ Triangle &tri = triangle[triIdx[i]];
255
+ int binIdx =
256
+ std::min(BINS - 1, (int)((tri.centroid[axis] - boundsMin) * scale));
257
+ count[binIdx]++;
258
+
259
+ __m128 v0 = _mm_set_ps(tri.v0.x, tri.v0.y, 0.0f, 0.0f);
260
+ __m128 v1 = _mm_set_ps(tri.v1.x, tri.v1.y, 0.0f, 0.0f);
261
+ __m128 v2 = _mm_set_ps(tri.v2.x, tri.v2.y, 0.0f, 0.0f);
262
+ min4[binIdx] = _mm_min_ps(min4[binIdx], v0);
263
+ max4[binIdx] = _mm_max_ps(max4[binIdx], v0);
264
+ min4[binIdx] = _mm_min_ps(min4[binIdx], v1);
265
+ max4[binIdx] = _mm_max_ps(max4[binIdx], v1);
266
+ min4[binIdx] = _mm_min_ps(min4[binIdx], v2);
267
+ max4[binIdx] = _mm_max_ps(max4[binIdx], v2);
268
+ }
269
+ // gather data for the 7 planes between the 8 bins
270
+ __m128 leftMin4 = _mm_set_ps1(FLT_MAX), rightMin4 = leftMin4;
271
+ __m128 leftMax4 = _mm_set_ps1(FLT_MIN), rightMax4 = leftMax4;
272
+ for (int i = 0; i < BINS - 1; i++) {
273
+ leftSum += count[i];
274
+ rightSum += count[BINS - 1 - i];
275
+ leftMin4 = _mm_min_ps(leftMin4, min4[i]);
276
+ rightMin4 = _mm_min_ps(rightMin4, min4[BINS - 2 - i]);
277
+ leftMax4 = _mm_max_ps(leftMax4, max4[i]);
278
+ rightMax4 = _mm_max_ps(rightMax4, max4[BINS - 2 - i]);
279
+ float le[4], re[4];
280
+ _mm_store_ps(le, _mm_sub_ps(leftMax4, leftMin4));
281
+ _mm_store_ps(re, _mm_sub_ps(rightMax4, rightMin4));
282
+ // SSE order goes from back to front
283
+ leftCountArea[i] = leftSum * (le[2] * le[3]); // 2D area calculation
284
+ rightCountArea[BINS - 2 - i] =
285
+ rightSum * (re[2] * re[3]); // 2D area calculation
286
+ }
287
+ }
288
+ #else
289
+ if constexpr (false) {
290
+ }
291
+ #endif
292
+ else {
293
+ struct Bin {
294
+ AABB bounds;
295
+ int triCount = 0;
296
+ } bin[BINS];
297
+ for (int i = node.start; i < node.end; i++) {
298
+ Triangle &tri = triangle[triIdx[i]];
299
+ int binIdx =
300
+ std::min(BINS - 1, (int)((tri.centroid[axis] - boundsMin) * scale));
301
+ bin[binIdx].triCount++;
302
+ bin[binIdx].bounds.grow(tri.v0);
303
+ bin[binIdx].bounds.grow(tri.v1);
304
+ bin[binIdx].bounds.grow(tri.v2);
305
+ }
306
+ // gather data for the 7 planes between the 8 bins
307
+ AABB leftBox, rightBox;
308
+ for (int i = 0; i < BINS - 1; i++) {
309
+ leftSum += bin[i].triCount;
310
+ leftBox.grow(bin[i].bounds);
311
+ leftCountArea[i] = leftSum * leftBox.area();
312
+ rightSum += bin[BINS - 1 - i].triCount;
313
+ rightBox.grow(bin[BINS - 1 - i].bounds);
314
+ rightCountArea[BINS - 2 - i] = rightSum * rightBox.area();
315
+ }
316
+ }
317
+
318
+ // calculate SAH cost for the 7 planes
319
+ scale = (boundsMax - boundsMin) / BINS;
320
+ for (int i = 0; i < BINS - 1; i++) {
321
+ const float planeCost = leftCountArea[i] + rightCountArea[i];
322
+ if (planeCost < best_cost)
323
+ best_axis = axis, best_pos = i + 1, best_cost = planeCost;
324
+ }
325
+ }
326
+ return best_cost;
327
+ }
328
+
329
+ std::vector<int> BVH::Intersect(Triangle &tri_intersect) {
330
+ /**
331
+ * @brief Intersect a triangle with the BVH
332
+ *
333
+ * @param triangle the triangle to intersect
334
+ *
335
+ * @return -1 for no intersection, the index of the intersected triangle
336
+ * otherwise
337
+ */
338
+
339
+ const int max_stack_size = 64;
340
+ int node_stack[max_stack_size];
341
+ int stack_size = 0;
342
+ std::vector<int> intersected_triangles;
343
+
344
+ node_stack[stack_size++] = 0; // Start with the root node (index 0)
345
+ while (stack_size > 0) {
346
+ int node_idx = node_stack[--stack_size];
347
+ const BVHNode &node = bvhNode[node_idx];
348
+ if (node.is_leaf()) {
349
+ for (int i = node.start; i < node.end; ++i) {
350
+ const Triangle &tri = triangle[triIdx[i]];
351
+ // Check that the triangle is not the same as the intersected triangle
352
+ if (tri == tri_intersect)
353
+ continue;
354
+ if (tri_intersect.overlaps(tri)) {
355
+ intersected_triangles.push_back(actualIdx[triIdx[i]]);
356
+ }
357
+ }
358
+ } else {
359
+ // Check right child first
360
+ if (bvhNode[node.right].bbox.overlaps(tri_intersect)) {
361
+ if (stack_size < max_stack_size) {
362
+ node_stack[stack_size++] = node.right;
363
+ } else {
364
+ throw std::runtime_error("Node stack overflow");
365
+ }
366
+ }
367
+
368
+ // Check left child
369
+ if (bvhNode[node.left].bbox.overlaps(tri_intersect)) {
370
+ if (stack_size < max_stack_size) {
371
+ node_stack[stack_size++] = node.left;
372
+ } else {
373
+ throw std::runtime_error("Node stack overflow");
374
+ }
375
+ }
376
+ }
377
+ }
378
+ return intersected_triangles; // Return all intersected triangle indices
379
+ }
380
+
381
+ } // namespace UVUnwrapper
uv_unwrapper/uv_unwrapper/csrc/bvh.h ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cfloat>
4
+ #include <cmath>
5
+ #ifndef __ARM_ARCH_ISA_A64
6
+ #include <immintrin.h>
7
+ #endif
8
+ #include <limits>
9
+ #include <vector>
10
+
11
+ #include "common.h"
12
+ #include "intersect.h"
13
+ /**
14
+ * Based on https://github.com/jbikker/bvh_article released under the unlicense.
15
+ */
16
+
17
+ // bin count for binned BVH building
18
+ #define BINS 8
19
+
20
+ namespace UVUnwrapper {
21
+ // minimalist triangle struct
22
+ struct alignas(32) Triangle {
23
+ uv_float2 v0;
24
+ uv_float2 v1;
25
+ uv_float2 v2;
26
+ uv_float2 centroid;
27
+
28
+ bool overlaps(const Triangle &other) {
29
+ // return tri_tri_overlap_test_2d(v0, v1, v2, other.v0, other.v1, other.v2);
30
+ return triangle_triangle_intersection(v0, v1, v2, other.v0, other.v1,
31
+ other.v2);
32
+ }
33
+
34
+ bool operator==(const Triangle &rhs) const {
35
+ return v0 == rhs.v0 && v1 == rhs.v1 && v2 == rhs.v2;
36
+ }
37
+ };
38
+
39
+ // minimalist AABB struct with grow functionality
40
+ struct alignas(16) AABB {
41
+ // Init bounding boxes with max/min
42
+ uv_float2 min = {FLT_MAX, FLT_MAX};
43
+ uv_float2 max = {FLT_MIN, FLT_MIN};
44
+
45
+ void grow(const uv_float2 &p) {
46
+ min.x = std::min(min.x, p.x);
47
+ min.y = std::min(min.y, p.y);
48
+ max.x = std::max(max.x, p.x);
49
+ max.y = std::max(max.y, p.y);
50
+ }
51
+
52
+ void grow(const AABB &b) {
53
+ if (b.min.x != FLT_MAX) {
54
+ grow(b.min);
55
+ grow(b.max);
56
+ }
57
+ }
58
+
59
+ bool overlaps(const Triangle &tri) {
60
+ return triangle_aabb_intersection(min, max, tri.v0, tri.v1, tri.v2);
61
+ }
62
+
63
+ float area() const {
64
+ uv_float2 extent = {max.x - min.x, max.y - min.y};
65
+ return extent.x * extent.y;
66
+ }
67
+
68
+ void invalidate() {
69
+ min = {FLT_MAX, FLT_MAX};
70
+ max = {FLT_MIN, FLT_MIN};
71
+ }
72
+ };
73
+
74
+ // 32-byte BVH node struct
75
+ struct alignas(32) BVHNode {
76
+ AABB bbox; // 16
77
+ int start = 0, end = 0; // 8
78
+ int left, right;
79
+
80
+ int num_triangles() const { return end - start; }
81
+
82
+ bool is_leaf() const { return left == -1 && right == -1; }
83
+
84
+ float calculate_node_cost() {
85
+ float area = bbox.area();
86
+ return num_triangles() * area;
87
+ }
88
+ };
89
+
90
+ class BVH {
91
+ public:
92
+ BVH() = default;
93
+ BVH(BVH &&other) noexcept;
94
+ BVH(const BVH &other);
95
+ BVH &operator=(const BVH &other);
96
+ BVH &operator=(BVH &&other) noexcept;
97
+ BVH(Triangle *tri, int *actual_idx, const size_t &num_indices);
98
+ ~BVH();
99
+
100
+ std::vector<int> Intersect(Triangle &triangle);
101
+
102
+ private:
103
+ void Subdivide(unsigned int node_idx, unsigned int &nodePtr,
104
+ AABB &centroidBounds);
105
+ void UpdateNodeBounds(unsigned int nodeIdx, AABB &centroidBounds);
106
+ float FindBestSplitPlane(BVHNode &node, int &axis, int &splitPos,
107
+ AABB &centroidBounds);
108
+
109
+ public:
110
+ int *triIdx = nullptr;
111
+ int *actualIdx = nullptr;
112
+ unsigned int triCount;
113
+ unsigned int nodesUsed;
114
+ BVHNode *bvhNode = nullptr;
115
+ Triangle *triangle = nullptr;
116
+ };
117
+
118
+ } // namespace UVUnwrapper
uv_unwrapper/uv_unwrapper/csrc/common.h ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <array>
4
+ #include <cmath>
5
+ #include <iostream>
6
+ #include <stdexcept>
7
+
8
+ const float EPSILON = 1e-7f;
9
+
10
+ // Structure to represent a 2D point or vector
11
union alignas(8) uv_float2 {
  struct {
    float x, y;
  };

  float data[2];

  // Checked element access: 0 -> x, 1 -> y; anything else throws.
  float &operator[](size_t idx) {
    if (idx >= 2) {
      throw std::runtime_error("bad index");
    }
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx >= 2) {
      throw std::runtime_error("bad index");
    }
    return data[idx];
  }

  // Exact component-wise comparison.
  bool operator==(const uv_float2 &rhs) const {
    if (x != rhs.x)
      return false;
    return y == rhs.y;
  }
};
34
+
35
+ // Do not align as this is specifically tweaked for BVHNode
36
union uv_float3 {
  struct {
    float x, y, z;
  };

  float data[3];

  // Checked element access: 0 -> x, 1 -> y, 2 -> z.
  // The bound is 2: the previous `idx > 3` check let index 3 through, which
  // reads/writes one element past the end of data[3].
  float &operator[](size_t idx) {
    if (idx > 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx > 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  // Exact component-wise comparison.
  bool operator==(const uv_float3 &rhs) const {
    return x == rhs.x && y == rhs.y && z == rhs.z;
  }
};
59
+
60
union alignas(16) uv_float4 {
  struct {
    float x, y, z, w;
  };

  float data[4];

  // Checked element access: 0..3 map to x, y, z, w; anything else throws.
  float &operator[](size_t idx) {
    if (idx >= 4) {
      throw std::runtime_error("bad index");
    }
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx >= 4) {
      throw std::runtime_error("bad index");
    }
    return data[idx];
  }

  // Exact component-wise comparison.
  bool operator==(const uv_float4 &rhs) const {
    const bool xy_equal = (x == rhs.x) && (y == rhs.y);
    const bool zw_equal = (z == rhs.z) && (w == rhs.w);
    return xy_equal && zw_equal;
  }
};
83
+
84
union alignas(8) uv_int2 {
  struct {
    int x, y;
  };

  int data[2];

  // Checked element access: 0 -> x, 1 -> y; anything else throws.
  int &operator[](size_t idx) {
    if (idx >= 2) {
      throw std::runtime_error("bad index");
    }
    return data[idx];
  }

  const int &operator[](size_t idx) const {
    if (idx >= 2) {
      throw std::runtime_error("bad index");
    }
    return data[idx];
  }

  bool operator==(const uv_int2 &rhs) const {
    if (x != rhs.x)
      return false;
    return y == rhs.y;
  }
};
105
+
106
union alignas(4) uv_int3 {
  struct {
    int x, y, z;
  };

  int data[3];

  // Checked element access: 0..2 map to x, y, z; anything else throws.
  int &operator[](size_t idx) {
    if (idx >= 3) {
      throw std::runtime_error("bad index");
    }
    return data[idx];
  }

  const int &operator[](size_t idx) const {
    if (idx >= 3) {
      throw std::runtime_error("bad index");
    }
    return data[idx];
  }

  bool operator==(const uv_int3 &rhs) const {
    if (x != rhs.x || y != rhs.y)
      return false;
    return z == rhs.z;
  }
};
129
+
130
union alignas(16) uv_int4 {
  struct {
    int x, y, z, w;
  };

  int data[4];

  // Checked element access: 0..3 map to x, y, z, w; anything else throws.
  int &operator[](size_t idx) {
    if (idx >= 4) {
      throw std::runtime_error("bad index");
    }
    return data[idx];
  }

  const int &operator[](size_t idx) const {
    if (idx >= 4) {
      throw std::runtime_error("bad index");
    }
    return data[idx];
  }

  bool operator==(const uv_int4 &rhs) const {
    const bool xy_equal = (x == rhs.x) && (y == rhs.y);
    const bool zw_equal = (z == rhs.z) && (w == rhs.w);
    return xy_equal && zw_equal;
  }
};
153
+
154
// Arithmetic mean of three floats.
inline float calc_mean(float a, float b, float c) {
  const float total = a + b + c;
  return total / 3;
}
155
+
156
+ // Create a triangle centroid
157
+ inline uv_float2 triangle_centroid(const uv_float2 &v0, const uv_float2 &v1,
158
+ const uv_float2 &v2) {
159
+ return {calc_mean(v0.x, v1.x, v2.x), calc_mean(v0.y, v1.y, v2.y)};
160
+ }
161
+
162
+ inline uv_float3 triangle_centroid(const uv_float3 &v0, const uv_float3 &v1,
163
+ const uv_float3 &v2) {
164
+ return {calc_mean(v0.x, v1.x, v2.x), calc_mean(v0.y, v1.y, v2.y),
165
+ calc_mean(v0.z, v1.z, v2.z)};
166
+ }
167
+
168
+ // Helper functions for vector math
169
+ inline uv_float2 operator-(const uv_float2 &a, const uv_float2 &b) {
170
+ return {a.x - b.x, a.y - b.y};
171
+ }
172
+
173
+ inline uv_float3 operator-(const uv_float3 &a, const uv_float3 &b) {
174
+ return {a.x - b.x, a.y - b.y, a.z - b.z};
175
+ }
176
+
177
+ inline uv_float2 operator+(const uv_float2 &a, const uv_float2 &b) {
178
+ return {a.x + b.x, a.y + b.y};
179
+ }
180
+
181
+ inline uv_float3 operator+(const uv_float3 &a, const uv_float3 &b) {
182
+ return {a.x + b.x, a.y + b.y, a.z + b.z};
183
+ }
184
+
185
+ inline uv_float2 operator*(const uv_float2 &a, float scalar) {
186
+ return {a.x * scalar, a.y * scalar};
187
+ }
188
+
189
+ inline uv_float3 operator*(const uv_float3 &a, float scalar) {
190
+ return {a.x * scalar, a.y * scalar, a.z * scalar};
191
+ }
192
+
193
+ inline float dot(const uv_float2 &a, const uv_float2 &b) {
194
+ return a.x * b.x + a.y * b.y;
195
+ }
196
+
197
+ inline float dot(const uv_float3 &a, const uv_float3 &b) {
198
+ return a.x * b.x + a.y * b.y + a.z * b.z;
199
+ }
200
+
201
+ inline float cross(const uv_float2 &a, const uv_float2 &b) {
202
+ return a.x * b.y - a.y * b.x;
203
+ }
204
+
205
+ inline uv_float3 cross(const uv_float3 &a, const uv_float3 &b) {
206
+ return {a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x};
207
+ }
208
+
209
+ inline uv_float2 abs_vec(const uv_float2 &v) {
210
+ return {std::abs(v.x), std::abs(v.y)};
211
+ }
212
+
213
+ inline uv_float2 min_vec(const uv_float2 &a, const uv_float2 &b) {
214
+ return {std::min(a.x, b.x), std::min(a.y, b.y)};
215
+ }
216
+
217
+ inline uv_float2 max_vec(const uv_float2 &a, const uv_float2 &b) {
218
+ return {std::max(a.x, b.x), std::max(a.y, b.y)};
219
+ }
220
+
221
+ inline float distance_to(const uv_float2 &a, const uv_float2 &b) {
222
+ return std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2));
223
+ }
224
+
225
+ inline float distance_to(const uv_float3 &a, const uv_float3 &b) {
226
+ return std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2) +
227
+ std::pow(a.z - b.z, 2));
228
+ }
229
+
230
+ inline uv_float2 normalize(const uv_float2 &v) {
231
+ float len = std::sqrt(v.x * v.x + v.y * v.y);
232
+ return {v.x / len, v.y / len};
233
+ }
234
+
235
+ inline uv_float3 normalize(const uv_float3 &v) {
236
+ float len = std::sqrt(v.x * v.x + v.y * v.y + v.z * v.z);
237
+ return {v.x / len, v.y / len, v.z / len};
238
+ }
239
+
240
+ inline float magnitude(const uv_float3 &v) {
241
+ return std::sqrt(v.x * v.x + v.y * v.y + v.z * v.z);
242
+ }
243
+
244
// 4x4 float matrix stored as m[row][col].
struct Matrix4 {
  std::array<std::array<float, 4>, 4> m;

  // Default state: all zeros except m[3][3] = 1 (this is NOT the identity
  // matrix; use identity() for that).
  Matrix4() {
    for (auto &row : m) {
      row.fill(0.0f);
    }
    m[3][3] = 1.0f; // Identity matrix for 4th row and column
  }

  // Assign all 16 entries; arguments are given row by row (mRC = row R,
  // column C).
  void set(float m00, float m01, float m02, float m03, float m10, float m11,
           float m12, float m13, float m20, float m21, float m22, float m23,
           float m30, float m31, float m32, float m33) {
    m[0][0] = m00;
    m[0][1] = m01;
    m[0][2] = m02;
    m[0][3] = m03;
    m[1][0] = m10;
    m[1][1] = m11;
    m[1][2] = m12;
    m[1][3] = m13;
    m[2][0] = m20;
    m[2][1] = m21;
    m[2][2] = m22;
    m[2][3] = m23;
    m[3][0] = m30;
    m[3][1] = m31;
    m[3][2] = m32;
    m[3][3] = m33;
  }

  // Determinant, written out as the full 24-term expansion.
  float determinant() const {
    return m[0][3] * m[1][2] * m[2][1] * m[3][0] -
           m[0][2] * m[1][3] * m[2][1] * m[3][0] -
           m[0][3] * m[1][1] * m[2][2] * m[3][0] +
           m[0][1] * m[1][3] * m[2][2] * m[3][0] +
           m[0][2] * m[1][1] * m[2][3] * m[3][0] -
           m[0][1] * m[1][2] * m[2][3] * m[3][0] -
           m[0][3] * m[1][2] * m[2][0] * m[3][1] +
           m[0][2] * m[1][3] * m[2][0] * m[3][1] +
           m[0][3] * m[1][0] * m[2][2] * m[3][1] -
           m[0][0] * m[1][3] * m[2][2] * m[3][1] -
           m[0][2] * m[1][0] * m[2][3] * m[3][1] +
           m[0][0] * m[1][2] * m[2][3] * m[3][1] +
           m[0][3] * m[1][1] * m[2][0] * m[3][2] -
           m[0][1] * m[1][3] * m[2][0] * m[3][2] -
           m[0][3] * m[1][0] * m[2][1] * m[3][2] +
           m[0][0] * m[1][3] * m[2][1] * m[3][2] +
           m[0][1] * m[1][0] * m[2][3] * m[3][2] -
           m[0][0] * m[1][1] * m[2][3] * m[3][2] -
           m[0][2] * m[1][1] * m[2][0] * m[3][3] +
           m[0][1] * m[1][2] * m[2][0] * m[3][3] +
           m[0][2] * m[1][0] * m[2][1] * m[3][3] -
           m[0][0] * m[1][2] * m[2][1] * m[3][3] -
           m[0][1] * m[1][0] * m[2][2] * m[3][3] +
           m[0][0] * m[1][1] * m[2][2] * m[3][3];
  }

  // Matrix product (*this) * other.
  Matrix4 operator*(const Matrix4 &other) const {
    Matrix4 result;
    for (int row = 0; row < 4; ++row) {
      for (int col = 0; col < 4; ++col) {
        result.m[row][col] =
            m[row][0] * other.m[0][col] + m[row][1] * other.m[1][col] +
            m[row][2] * other.m[2][col] + m[row][3] * other.m[3][col];
      }
    }
    return result;
  }

  // Multiply every entry by `scalar`.
  Matrix4 operator*(float scalar) const {
    Matrix4 result = *this;
    for (auto &row : result.m) {
      for (auto &element : row) {
        element *= scalar;
      }
    }
    return result;
  }

  // Entry-wise sum.
  Matrix4 operator+(const Matrix4 &other) const {
    Matrix4 result;
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        result.m[i][j] = m[i][j] + other.m[i][j];
      }
    }
    return result;
  }

  // Entry-wise difference.
  Matrix4 operator-(const Matrix4 &other) const {
    Matrix4 result;
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        result.m[i][j] = m[i][j] - other.m[i][j];
      }
    }
    return result;
  }

  // Sum of the diagonal entries.
  float trace() const { return m[0][0] + m[1][1] + m[2][2] + m[3][3]; }

  // Returns a new identity matrix; does not read or modify *this.
  Matrix4 identity() const {
    Matrix4 identity;
    identity.set(1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1);
    return identity;
  }

  // Integer power by repeated multiplication. exp == 0 yields the identity;
  // for exp <= 1 other than 0 (including negative values) the loop does not
  // run and a copy of *this is returned.
  Matrix4 power(int exp) const {
    if (exp == 0)
      return identity();
    if (exp == 1)
      return *this;

    Matrix4 result = *this;
    for (int i = 1; i < exp; ++i) {
      result = result * (*this);
    }
    return result;
  }

  // Write the matrix to std::cout, one row per line.
  void print() {
    // Print all entries in 4 rows with 4 columns
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        std::cout << m[i][j] << " ";
      }
      std::cout << std::endl;
    }
  }

  // Invert the matrix in place using the cofactor expansion, computed in
  // double precision. Returns false and leaves the matrix unchanged when
  // |det| < 1e-6; returns true otherwise.
  bool invert() {
    double inv[16], det;
    double mArr[16];

    // Convert the matrix to a 1D array for easier manipulation
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        mArr[i * 4 + j] = static_cast<double>(m[i][j]);
      }
    }

    inv[0] = mArr[5] * mArr[10] * mArr[15] - mArr[5] * mArr[11] * mArr[14] -
             mArr[9] * mArr[6] * mArr[15] + mArr[9] * mArr[7] * mArr[14] +
             mArr[13] * mArr[6] * mArr[11] - mArr[13] * mArr[7] * mArr[10];

    inv[4] = -mArr[4] * mArr[10] * mArr[15] + mArr[4] * mArr[11] * mArr[14] +
             mArr[8] * mArr[6] * mArr[15] - mArr[8] * mArr[7] * mArr[14] -
             mArr[12] * mArr[6] * mArr[11] + mArr[12] * mArr[7] * mArr[10];

    inv[8] = mArr[4] * mArr[9] * mArr[15] - mArr[4] * mArr[11] * mArr[13] -
             mArr[8] * mArr[5] * mArr[15] + mArr[8] * mArr[7] * mArr[13] +
             mArr[12] * mArr[5] * mArr[11] - mArr[12] * mArr[7] * mArr[9];

    inv[12] = -mArr[4] * mArr[9] * mArr[14] + mArr[4] * mArr[10] * mArr[13] +
              mArr[8] * mArr[5] * mArr[14] - mArr[8] * mArr[6] * mArr[13] -
              mArr[12] * mArr[5] * mArr[10] + mArr[12] * mArr[6] * mArr[9];

    inv[1] = -mArr[1] * mArr[10] * mArr[15] + mArr[1] * mArr[11] * mArr[14] +
             mArr[9] * mArr[2] * mArr[15] - mArr[9] * mArr[3] * mArr[14] -
             mArr[13] * mArr[2] * mArr[11] + mArr[13] * mArr[3] * mArr[10];

    inv[5] = mArr[0] * mArr[10] * mArr[15] - mArr[0] * mArr[11] * mArr[14] -
             mArr[8] * mArr[2] * mArr[15] + mArr[8] * mArr[3] * mArr[14] +
             mArr[12] * mArr[2] * mArr[11] - mArr[12] * mArr[3] * mArr[10];

    inv[9] = -mArr[0] * mArr[9] * mArr[15] + mArr[0] * mArr[11] * mArr[13] +
             mArr[8] * mArr[1] * mArr[15] - mArr[8] * mArr[3] * mArr[13] -
             mArr[12] * mArr[1] * mArr[11] + mArr[12] * mArr[3] * mArr[9];

    inv[13] = mArr[0] * mArr[9] * mArr[14] - mArr[0] * mArr[10] * mArr[13] -
              mArr[8] * mArr[1] * mArr[14] + mArr[8] * mArr[2] * mArr[13] +
              mArr[12] * mArr[1] * mArr[10] - mArr[12] * mArr[2] * mArr[9];

    inv[2] = mArr[1] * mArr[6] * mArr[15] - mArr[1] * mArr[7] * mArr[14] -
             mArr[5] * mArr[2] * mArr[15] + mArr[5] * mArr[3] * mArr[14] +
             mArr[13] * mArr[2] * mArr[7] - mArr[13] * mArr[3] * mArr[6];

    inv[6] = -mArr[0] * mArr[6] * mArr[15] + mArr[0] * mArr[7] * mArr[14] +
             mArr[4] * mArr[2] * mArr[15] - mArr[4] * mArr[3] * mArr[14] -
             mArr[12] * mArr[2] * mArr[7] + mArr[12] * mArr[3] * mArr[6];

    inv[10] = mArr[0] * mArr[5] * mArr[15] - mArr[0] * mArr[7] * mArr[13] -
              mArr[4] * mArr[1] * mArr[15] + mArr[4] * mArr[3] * mArr[13] +
              mArr[12] * mArr[1] * mArr[7] - mArr[12] * mArr[3] * mArr[5];

    inv[14] = -mArr[0] * mArr[5] * mArr[14] + mArr[0] * mArr[6] * mArr[13] +
              mArr[4] * mArr[1] * mArr[14] - mArr[4] * mArr[2] * mArr[13] -
              mArr[12] * mArr[1] * mArr[6] + mArr[12] * mArr[2] * mArr[5];

    inv[3] = -mArr[1] * mArr[6] * mArr[11] + mArr[1] * mArr[7] * mArr[10] +
             mArr[5] * mArr[2] * mArr[11] - mArr[5] * mArr[3] * mArr[10] -
             mArr[9] * mArr[2] * mArr[7] + mArr[9] * mArr[3] * mArr[6];

    inv[7] = mArr[0] * mArr[6] * mArr[11] - mArr[0] * mArr[7] * mArr[10] -
             mArr[4] * mArr[2] * mArr[11] + mArr[4] * mArr[3] * mArr[10] +
             mArr[8] * mArr[2] * mArr[7] - mArr[8] * mArr[3] * mArr[6];

    inv[11] = -mArr[0] * mArr[5] * mArr[11] + mArr[0] * mArr[7] * mArr[9] +
              mArr[4] * mArr[1] * mArr[11] - mArr[4] * mArr[3] * mArr[9] -
              mArr[8] * mArr[1] * mArr[7] + mArr[8] * mArr[3] * mArr[5];

    inv[15] = mArr[0] * mArr[5] * mArr[10] - mArr[0] * mArr[6] * mArr[9] -
              mArr[4] * mArr[1] * mArr[10] + mArr[4] * mArr[2] * mArr[9] +
              mArr[8] * mArr[1] * mArr[6] - mArr[8] * mArr[2] * mArr[5];

    // Determinant from the first row and the first column of cofactors
    det = mArr[0] * inv[0] + mArr[1] * inv[4] + mArr[2] * inv[8] +
          mArr[3] * inv[12];

    // Treat near-zero determinants as singular
    if (fabs(det) < 1e-6) {
      return false;
    }

    det = 1.0 / det;

    for (int i = 0; i < 16; i++) {
      inv[i] *= det;
    }

    // Convert the 1D array back to the 4x4 matrix
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        m[i][j] = static_cast<float>(inv[i * 4 + j]);
      }
    }

    return true;
  }
};
473
+
474
+ inline void apply_matrix4(uv_float3 &v, const Matrix4 matrix) {
475
+ float newX = v.x * matrix.m[0][0] + v.y * matrix.m[0][1] +
476
+ v.z * matrix.m[0][2] + matrix.m[0][3];
477
+ float newY = v.x * matrix.m[1][0] + v.y * matrix.m[1][1] +
478
+ v.z * matrix.m[1][2] + matrix.m[1][3];
479
+ float newZ = v.x * matrix.m[2][0] + v.y * matrix.m[2][1] +
480
+ v.z * matrix.m[2][2] + matrix.m[2][3];
481
+ float w = v.x * matrix.m[3][0] + v.y * matrix.m[3][1] + v.z * matrix.m[3][2] +
482
+ matrix.m[3][3];
483
+
484
+ if (std::fabs(w) > EPSILON) {
485
+ newX /= w;
486
+ newY /= w;
487
+ newZ /= w;
488
+ }
489
+
490
+ v.x = newX;
491
+ v.y = newY;
492
+ v.z = newZ;
493
+ }
uv_unwrapper/uv_unwrapper/csrc/intersect.cpp ADDED
@@ -0,0 +1,702 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "intersect.h"
2
+ #include "bvh.h"
3
+ #include <algorithm>
4
+ #include <cmath>
5
+ #include <iostream>
6
+ #include <stdexcept>
7
+ #include <vector>
8
+
9
+ bool triangle_aabb_intersection(const uv_float2 &aabbMin,
10
+ const uv_float2 &aabbMax, const uv_float2 &v0,
11
+ const uv_float2 &v1, const uv_float2 &v2) {
12
+ // Convert the min and max aabb defintion to left, right, top, bottom
13
+ float l = aabbMin.x;
14
+ float r = aabbMax.x;
15
+ float t = aabbMin.y;
16
+ float b = aabbMax.y;
17
+
18
+ int b0 = ((v0.x > l) ? 1 : 0) | ((v0.y > t) ? 2 : 0) | ((v0.x > r) ? 4 : 0) |
19
+ ((v0.y > b) ? 8 : 0);
20
+ if (b0 == 3)
21
+ return true;
22
+
23
+ int b1 = ((v1.x > l) ? 1 : 0) | ((v1.y > t) ? 2 : 0) | ((v1.x > r) ? 4 : 0) |
24
+ ((v1.y > b) ? 8 : 0);
25
+ if (b1 == 3)
26
+ return true;
27
+
28
+ int b2 = ((v2.x > l) ? 1 : 0) | ((v2.y > t) ? 2 : 0) | ((v2.x > r) ? 4 : 0) |
29
+ ((v2.y > b) ? 8 : 0);
30
+ if (b2 == 3)
31
+ return true;
32
+
33
+ float m, c, s;
34
+
35
+ int i0 = b0 ^ b1;
36
+ if (i0 != 0) {
37
+ if (v1.x != v0.x) {
38
+ m = (v1.y - v0.y) / (v1.x - v0.x);
39
+ c = v0.y - (m * v0.x);
40
+ if (i0 & 1) {
41
+ s = m * l + c;
42
+ if (s >= t && s <= b)
43
+ return true;
44
+ }
45
+ if (i0 & 2) {
46
+ s = (t - c) / m;
47
+ if (s >= l && s <= r)
48
+ return true;
49
+ }
50
+ if (i0 & 4) {
51
+ s = m * r + c;
52
+ if (s >= t && s <= b)
53
+ return true;
54
+ }
55
+ if (i0 & 8) {
56
+ s = (b - c) / m;
57
+ if (s >= l && s <= r)
58
+ return true;
59
+ }
60
+ } else {
61
+ if (l == v0.x || r == v0.x)
62
+ return true;
63
+ if (v0.x > l && v0.x < r)
64
+ return true;
65
+ }
66
+ }
67
+
68
+ int i1 = b1 ^ b2;
69
+ if (i1 != 0) {
70
+ if (v2.x != v1.x) {
71
+ m = (v2.y - v1.y) / (v2.x - v1.x);
72
+ c = v1.y - (m * v1.x);
73
+ if (i1 & 1) {
74
+ s = m * l + c;
75
+ if (s >= t && s <= b)
76
+ return true;
77
+ }
78
+ if (i1 & 2) {
79
+ s = (t - c) / m;
80
+ if (s >= l && s <= r)
81
+ return true;
82
+ }
83
+ if (i1 & 4) {
84
+ s = m * r + c;
85
+ if (s >= t && s <= b)
86
+ return true;
87
+ }
88
+ if (i1 & 8) {
89
+ s = (b - c) / m;
90
+ if (s >= l && s <= r)
91
+ return true;
92
+ }
93
+ } else {
94
+ if (l == v1.x || r == v1.x)
95
+ return true;
96
+ if (v1.x > l && v1.x < r)
97
+ return true;
98
+ }
99
+ }
100
+
101
+ int i2 = b0 ^ b2;
102
+ if (i2 != 0) {
103
+ if (v2.x != v0.x) {
104
+ m = (v2.y - v0.y) / (v2.x - v0.x);
105
+ c = v0.y - (m * v0.x);
106
+ if (i2 & 1) {
107
+ s = m * l + c;
108
+ if (s >= t && s <= b)
109
+ return true;
110
+ }
111
+ if (i2 & 2) {
112
+ s = (t - c) / m;
113
+ if (s >= l && s <= r)
114
+ return true;
115
+ }
116
+ if (i2 & 4) {
117
+ s = m * r + c;
118
+ if (s >= t && s <= b)
119
+ return true;
120
+ }
121
+ if (i2 & 8) {
122
+ s = (b - c) / m;
123
+ if (s >= l && s <= r)
124
+ return true;
125
+ }
126
+ } else {
127
+ if (l == v0.x || r == v0.x)
128
+ return true;
129
+ if (v0.x > l && v0.x < r)
130
+ return true;
131
+ }
132
+ }
133
+
134
+ // Bounding box check
135
+ float tbb_l = std::min(v0.x, std::min(v1.x, v2.x));
136
+ float tbb_t = std::min(v0.y, std::min(v1.y, v2.y));
137
+ float tbb_r = std::max(v0.x, std::max(v1.x, v2.x));
138
+ float tbb_b = std::max(v0.y, std::max(v1.y, v2.y));
139
+
140
+ if (tbb_l <= l && tbb_r >= r && tbb_t <= t && tbb_b >= b) {
141
+ float v0x = v2.x - v0.x;
142
+ float v0y = v2.y - v0.y;
143
+ float v1x = v1.x - v0.x;
144
+ float v1y = v1.y - v0.y;
145
+ float v2x, v2y;
146
+
147
+ float dot00, dot01, dot02, dot11, dot12, invDenom, u, v;
148
+
149
+ // Top-left corner
150
+ v2x = l - v0.x;
151
+ v2y = t - v0.y;
152
+
153
+ dot00 = v0x * v0x + v0y * v0y;
154
+ dot01 = v0x * v1x + v0y * v1y;
155
+ dot02 = v0x * v2x + v0y * v2y;
156
+ dot11 = v1x * v1x + v1y * v1y;
157
+ dot12 = v1x * v2x + v1y * v2y;
158
+
159
+ invDenom = 1.0f / (dot00 * dot11 - dot01 * dot01);
160
+ u = (dot11 * dot02 - dot01 * dot12) * invDenom;
161
+ v = (dot00 * dot12 - dot01 * dot02) * invDenom;
162
+
163
+ if (u >= 0 && v >= 0 && (u + v) <= 1)
164
+ return true;
165
+
166
+ // Bottom-left corner
167
+ v2x = l - v0.x;
168
+ v2y = b - v0.y;
169
+
170
+ dot02 = v0x * v2x + v0y * v2y;
171
+ dot12 = v1x * v2x + v1y * v2y;
172
+
173
+ u = (dot11 * dot02 - dot01 * dot12) * invDenom;
174
+ v = (dot00 * dot12 - dot01 * dot02) * invDenom;
175
+
176
+ if (u >= 0 && v >= 0 && (u + v) <= 1)
177
+ return true;
178
+
179
+ // Bottom-right corner
180
+ v2x = r - v0.x;
181
+ v2y = b - v0.y;
182
+
183
+ dot02 = v0x * v2x + v0y * v2y;
184
+ dot12 = v1x * v2x + v1y * v2y;
185
+
186
+ u = (dot11 * dot02 - dot01 * dot12) * invDenom;
187
+ v = (dot00 * dot12 - dot01 * dot02) * invDenom;
188
+
189
+ if (u >= 0 && v >= 0 && (u + v) <= 1)
190
+ return true;
191
+
192
+ // Top-right corner
193
+ v2x = r - v0.x;
194
+ v2y = t - v0.y;
195
+
196
+ dot02 = v0x * v2x + v0y * v2y;
197
+ dot12 = v1x * v2x + v1y * v2y;
198
+
199
+ u = (dot11 * dot02 - dot01 * dot12) * invDenom;
200
+ v = (dot00 * dot12 - dot01 * dot02) * invDenom;
201
+
202
+ if (u >= 0 && v >= 0 && (u + v) <= 1)
203
+ return true;
204
+ }
205
+
206
+ return false;
207
+ }
208
+
209
+ void tri_winding(uv_float2 &a, uv_float2 &b, uv_float2 &c) {
210
+ float det = (a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y));
211
+
212
+ // If the determinant is negative, the triangle is oriented clockwise
213
+ if (det < 0) {
214
+ // Swap vertices b and c to ensure counter-clockwise winding
215
+ std::swap(b, c);
216
+ }
217
+ }
218
+
219
+ struct Triangle {
220
+ uv_float3 a, b, c;
221
+
222
+ Triangle(const uv_float2 &p1, const uv_float2 &q1, const uv_float2 &r1)
223
+ : a({p1.x, p1.y, 0}), b({q1.x, q1.y, 0}), c({r1.x, r1.y, 0}) {}
224
+
225
+ Triangle(const uv_float3 &p1, const uv_float3 &q1, const uv_float3 &r1)
226
+ : a(p1), b(q1), c(r1) {}
227
+
228
+ void getNormal(uv_float3 &normal) const {
229
+ uv_float3 u = b - a;
230
+ uv_float3 v = c - a;
231
+ normal = normalize(cross(u, v));
232
+ }
233
+ };
234
+
235
+ bool isTriDegenerated(const Triangle &tri) {
236
+ uv_float3 u = tri.a - tri.b;
237
+ uv_float3 v = tri.a - tri.c;
238
+ uv_float3 cr = cross(u, v);
239
+ return fabs(cr.x) < EPSILON && fabs(cr.y) < EPSILON && fabs(cr.z) < EPSILON;
240
+ }
241
+
242
+ int orient3D(const uv_float3 &a, const uv_float3 &b, const uv_float3 &c,
243
+ const uv_float3 &d) {
244
+ Matrix4 _matrix4;
245
+ _matrix4.set(a.x, a.y, a.z, 1, b.x, b.y, b.z, 1, c.x, c.y, c.z, 1, d.x, d.y,
246
+ d.z, 1);
247
+ float det = _matrix4.determinant();
248
+
249
+ if (det < -EPSILON)
250
+ return -1;
251
+ else if (det > EPSILON)
252
+ return 1;
253
+ else
254
+ return 0;
255
+ }
256
+
257
+ int orient2D(const uv_float2 &a, const uv_float2 &b, const uv_float2 &c) {
258
+ float det = (a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y));
259
+
260
+ if (det < -EPSILON)
261
+ return -1;
262
+ else if (det > EPSILON)
263
+ return 1;
264
+ else
265
+ return 0;
266
+ }
267
+
268
+ int orient2D(const uv_float3 &a, const uv_float3 &b, const uv_float3 &c) {
269
+ uv_float2 a_2d = {a.x, a.y};
270
+ uv_float2 b_2d = {b.x, b.y};
271
+ uv_float2 c_2d = {c.x, c.y};
272
+ return orient2D(a_2d, b_2d, c_2d);
273
+ }
274
+
275
+ void permuteTriLeft(Triangle &tri) {
276
+ uv_float3 tmp = tri.a;
277
+ tri.a = tri.b;
278
+ tri.b = tri.c;
279
+ tri.c = tmp;
280
+ }
281
+
282
+ void permuteTriRight(Triangle &tri) {
283
+ uv_float3 tmp = tri.c;
284
+ tri.c = tri.b;
285
+ tri.b = tri.a;
286
+ tri.a = tmp;
287
+ }
288
+
289
+ void makeTriCounterClockwise(Triangle &tri) {
290
+ if (orient2D(tri.a, tri.b, tri.c) < 0) {
291
+ uv_float3 tmp = tri.c;
292
+ tri.c = tri.b;
293
+ tri.b = tmp;
294
+ }
295
+ }
296
+
297
+ void intersectPlane(const uv_float3 &a, const uv_float3 &b, const uv_float3 &p,
298
+ const uv_float3 &n, uv_float3 &target) {
299
+ uv_float3 u = b - a;
300
+ uv_float3 v = a - p;
301
+ float dot1 = dot(n, u);
302
+ float dot2 = dot(n, v);
303
+ u = u * (-dot2 / dot1);
304
+ target = a + u;
305
+ }
306
+
307
+ void computeLineIntersection(const Triangle &t1, const Triangle &t2,
308
+ std::vector<uv_float3> &target) {
309
+ uv_float3 n1, n2;
310
+ t1.getNormal(n1);
311
+ t2.getNormal(n2);
312
+
313
+ int o1 = orient3D(t1.a, t1.c, t2.b, t2.a);
314
+ int o2 = orient3D(t1.a, t1.b, t2.c, t2.a);
315
+
316
+ uv_float3 i1, i2;
317
+
318
+ if (o1 > 0) {
319
+ if (o2 > 0) {
320
+ intersectPlane(t1.a, t1.c, t2.a, n2, i1);
321
+ intersectPlane(t2.a, t2.c, t1.a, n1, i2);
322
+ } else {
323
+ intersectPlane(t1.a, t1.c, t2.a, n2, i1);
324
+ intersectPlane(t1.a, t1.b, t2.a, n2, i2);
325
+ }
326
+ } else {
327
+ if (o2 > 0) {
328
+ intersectPlane(t2.a, t2.b, t1.a, n1, i1);
329
+ intersectPlane(t2.a, t2.c, t1.a, n1, i2);
330
+ } else {
331
+ intersectPlane(t2.a, t2.b, t1.a, n1, i1);
332
+ intersectPlane(t1.a, t1.b, t2.a, n2, i2);
333
+ }
334
+ }
335
+
336
+ target.push_back(i1);
337
+ if (distance_to(i1, i2) >= EPSILON) {
338
+ target.push_back(i2);
339
+ }
340
+ }
341
+
342
+ void makeTriAVertexAlone(Triangle &tri, int oa, int ob, int oc) {
343
+ // Permute a, b, c so that a is alone on its side
344
+ if (oa == ob) {
345
+ // c is alone, permute right so c becomes a
346
+ permuteTriRight(tri);
347
+ } else if (oa == oc) {
348
+ // b is alone, permute so b becomes a
349
+ permuteTriLeft(tri);
350
+ } else if (ob != oc) {
351
+ // In case a, b, c have different orientation, put a on positive side
352
+ if (ob > 0) {
353
+ permuteTriLeft(tri);
354
+ } else if (oc > 0) {
355
+ permuteTriRight(tri);
356
+ }
357
+ }
358
+ }
359
+
360
+ void makeTriAVertexPositive(Triangle &tri, const Triangle &other) {
361
+ int o = orient3D(other.a, other.b, other.c, tri.a);
362
+ if (o < 0) {
363
+ std::swap(tri.b, tri.c);
364
+ }
365
+ }
366
+
367
+ bool crossIntersect(Triangle &t1, Triangle &t2, int o1a, int o1b, int o1c,
368
+ std::vector<uv_float3> *target = nullptr) {
369
+ int o2a = orient3D(t1.a, t1.b, t1.c, t2.a);
370
+ int o2b = orient3D(t1.a, t1.b, t1.c, t2.b);
371
+ int o2c = orient3D(t1.a, t1.b, t1.c, t2.c);
372
+
373
+ if (o2a == o2b && o2a == o2c) {
374
+ return false;
375
+ }
376
+
377
+ // Make a vertex alone on its side for both triangles
378
+ makeTriAVertexAlone(t1, o1a, o1b, o1c);
379
+ makeTriAVertexAlone(t2, o2a, o2b, o2c);
380
+
381
+ // Ensure the vertex on the positive side
382
+ makeTriAVertexPositive(t2, t1);
383
+ makeTriAVertexPositive(t1, t2);
384
+
385
+ int o1 = orient3D(t1.a, t1.b, t2.a, t2.b);
386
+ int o2 = orient3D(t1.a, t1.c, t2.c, t2.a);
387
+
388
+ if (o1 <= 0 && o2 <= 0) {
389
+ if (target) {
390
+ computeLineIntersection(t1, t2, *target);
391
+ }
392
+ return true;
393
+ }
394
+
395
+ return false;
396
+ }
397
+
398
+ void linesIntersect2d(const uv_float3 &a1, const uv_float3 &b1,
399
+ const uv_float3 &a2, const uv_float3 &b2,
400
+ uv_float3 &target) {
401
+ float dx1 = a1.x - b1.x;
402
+ float dx2 = a2.x - b2.x;
403
+ float dy1 = a1.y - b1.y;
404
+ float dy2 = a2.y - b2.y;
405
+
406
+ float D = dx1 * dy2 - dx2 * dy1;
407
+
408
+ float n1 = a1.x * b1.y - a1.y * b1.x;
409
+ float n2 = a2.x * b2.y - a2.y * b2.x;
410
+
411
+ target.x = (n1 * dx2 - n2 * dx1) / D;
412
+ target.y = (n1 * dy2 - n2 * dy1) / D;
413
+ target.z = 0;
414
+ }
415
+
416
+ void clipTriangle(const Triangle &t1, const Triangle &t2,
417
+ std::vector<uv_float3> &target) {
418
+ std::vector<uv_float3> clip = {t1.a, t1.b, t1.c};
419
+ std::vector<uv_float3> output = {t2.a, t2.b, t2.c};
420
+ std::vector<int> orients(output.size() * 3, 0);
421
+ uv_float3 inter;
422
+
423
+ for (int i = 0; i < 3; ++i) {
424
+ const int i_prev = (i + 2) % 3;
425
+ std::vector<uv_float3> input;
426
+ std::copy(output.begin(), output.end(), std::back_inserter(input));
427
+ output.clear();
428
+
429
+ for (size_t j = 0; j < input.size(); ++j) {
430
+ orients[j] = orient2D(clip[i_prev], clip[i], input[j]);
431
+ }
432
+
433
+ for (size_t j = 0; j < input.size(); ++j) {
434
+ const int j_prev = (j - 1 + input.size()) % input.size();
435
+
436
+ if (orients[j] >= 0) {
437
+ if (orients[j_prev] < 0) {
438
+ linesIntersect2d(clip[i_prev], clip[i], input[j_prev], input[j],
439
+ inter);
440
+ output.push_back({inter.x, inter.y, inter.z});
441
+ }
442
+ output.push_back({input[j].x, input[j].y, input[j].z});
443
+ } else if (orients[j_prev] >= 0) {
444
+ linesIntersect2d(clip[i_prev], clip[i], input[j_prev], input[j], inter);
445
+ output.push_back({inter.x, inter.y, inter.z});
446
+ }
447
+ }
448
+ }
449
+
450
+ // Clear duplicated points
451
+ for (const auto &point : output) {
452
+ int j = 0;
453
+ bool sameFound = false;
454
+ while (!sameFound && j < target.size()) {
455
+ sameFound = distance_to(point, target[j]) <= 1e-6;
456
+ j++;
457
+ }
458
+
459
+ if (!sameFound) {
460
+ target.push_back(point);
461
+ }
462
+ }
463
+ }
464
+
465
+ bool intersectionTypeR1(const Triangle &t1, const Triangle &t2) {
466
+ const uv_float3 &p1 = t1.a;
467
+ const uv_float3 &q1 = t1.b;
468
+ const uv_float3 &r1 = t1.c;
469
+ const uv_float3 &p2 = t2.a;
470
+ const uv_float3 &r2 = t2.c;
471
+
472
+ if (orient2D(r2, p2, q1) >= 0) { // I
473
+ if (orient2D(r2, p1, q1) >= 0) { // II.a
474
+ if (orient2D(p1, p2, q1) >= 0) { // III.a
475
+ return true;
476
+ } else {
477
+ if (orient2D(p1, p2, r1) >= 0) { // IV.a
478
+ if (orient2D(q1, r1, p2) >= 0) { // V
479
+ return true;
480
+ }
481
+ }
482
+ }
483
+ }
484
+ } else {
485
+ if (orient2D(r2, p2, r1) >= 0) { // II.b
486
+ if (orient2D(q1, r1, r2) >= 0) { // III.b
487
+ if (orient2D(p1, p2, r1) >= 0) { // IV.b (diverges from paper)
488
+ return true;
489
+ }
490
+ }
491
+ }
492
+ }
493
+
494
+ return false;
495
+ }
496
+
497
+ bool intersectionTypeR2(const Triangle &t1, const Triangle &t2) {
498
+ const uv_float3 &p1 = t1.a;
499
+ const uv_float3 &q1 = t1.b;
500
+ const uv_float3 &r1 = t1.c;
501
+ const uv_float3 &p2 = t2.a;
502
+ const uv_float3 &q2 = t2.b;
503
+ const uv_float3 &r2 = t2.c;
504
+
505
+ if (orient2D(r2, p2, q1) >= 0) { // I
506
+ if (orient2D(q2, r2, q1) >= 0) { // II.a
507
+ if (orient2D(p1, p2, q1) >= 0) { // III.a
508
+ if (orient2D(p1, q2, q1) <= 0) { // IV.a
509
+ return true;
510
+ }
511
+ } else {
512
+ if (orient2D(p1, p2, r1) >= 0) { // IV.b
513
+ if (orient2D(r2, p2, r1) <= 0) { // V.a
514
+ return true;
515
+ }
516
+ }
517
+ }
518
+ } else {
519
+ if (orient2D(p1, q2, q1) <= 0) { // III.b
520
+ if (orient2D(q2, r2, r1) >= 0) { // IV.c
521
+ if (orient2D(q1, r1, q2) >= 0) { // V.b
522
+ return true;
523
+ }
524
+ }
525
+ }
526
+ }
527
+ } else {
528
+ if (orient2D(r2, p2, r1) >= 0) { // II.b
529
+ if (orient2D(q1, r1, r2) >= 0) { // III.c
530
+ if (orient2D(r1, p1, p2) >= 0) { // IV.d
531
+ return true;
532
+ }
533
+ } else {
534
+ if (orient2D(q1, r1, q2) >= 0) { // IV.e
535
+ if (orient2D(q2, r2, r1) >= 0) { // V.c
536
+ return true;
537
+ }
538
+ }
539
+ }
540
+ }
541
+ }
542
+
543
+ return false;
544
+ }
545
+
546
+ bool coplanarIntersect(Triangle &t1, Triangle &t2,
547
+ std::vector<uv_float3> *target = nullptr) {
548
+ uv_float3 normal, u, v;
549
+ t1.getNormal(normal);
550
+ normal = normalize(normal);
551
+ u = normalize(t1.a - t1.b);
552
+ v = cross(normal, u);
553
+
554
+ // Move basis to t1.a
555
+ u = u + t1.a;
556
+ v = v + t1.a;
557
+ normal = normal + t1.a;
558
+
559
+ Matrix4 _matrix;
560
+ _matrix.set(t1.a.x, u.x, v.x, normal.x, t1.a.y, u.y, v.y, normal.y, t1.a.z,
561
+ u.z, v.z, normal.z, 1, 1, 1, 1);
562
+
563
+ Matrix4 _affineMatrix;
564
+ _affineMatrix.set(0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1);
565
+
566
+ _matrix.invert(); // Invert the _matrix
567
+ _matrix = _affineMatrix * _matrix;
568
+
569
+ // Apply transformation
570
+ apply_matrix4(t1.a, _matrix);
571
+ apply_matrix4(t1.b, _matrix);
572
+ apply_matrix4(t1.c, _matrix);
573
+ apply_matrix4(t2.a, _matrix);
574
+ apply_matrix4(t2.b, _matrix);
575
+ apply_matrix4(t2.c, _matrix);
576
+
577
+ makeTriCounterClockwise(t1);
578
+ makeTriCounterClockwise(t2);
579
+
580
+ const uv_float3 &p1 = t1.a;
581
+ const uv_float3 &p2 = t2.a;
582
+ const uv_float3 &q2 = t2.b;
583
+ const uv_float3 &r2 = t2.c;
584
+
585
+ int o_p2q2 = orient2D(p2, q2, p1);
586
+ int o_q2r2 = orient2D(q2, r2, p1);
587
+ int o_r2p2 = orient2D(r2, p2, p1);
588
+
589
+ bool intersecting = false;
590
+ if (o_p2q2 >= 0) {
591
+ if (o_q2r2 >= 0) {
592
+ if (o_r2p2 >= 0) {
593
+ // + + +
594
+ intersecting = true;
595
+ } else {
596
+ // + + -
597
+ intersecting = intersectionTypeR1(t1, t2);
598
+ }
599
+ } else {
600
+ if (o_r2p2 >= 0) {
601
+ // + - +
602
+ permuteTriRight(t2);
603
+ intersecting = intersectionTypeR1(t1, t2);
604
+ } else {
605
+ // + - -
606
+ intersecting = intersectionTypeR2(t1, t2);
607
+ }
608
+ }
609
+ } else {
610
+ if (o_q2r2 >= 0) {
611
+ if (o_r2p2 >= 0) {
612
+ // - + +
613
+ permuteTriLeft(t2);
614
+ intersecting = intersectionTypeR1(t1, t2);
615
+ } else {
616
+ // - + -
617
+ permuteTriLeft(t2);
618
+ intersecting = intersectionTypeR2(t1, t2);
619
+ }
620
+ } else {
621
+ if (o_r2p2 >= 0) {
622
+ // - - +
623
+ permuteTriRight(t2);
624
+ intersecting = intersectionTypeR2(t1, t2);
625
+ } else {
626
+ // - - -
627
+ std::cerr << "Triangles should not be flat." << std::endl;
628
+ return false;
629
+ }
630
+ }
631
+ }
632
+
633
+ if (intersecting && target) {
634
+ clipTriangle(t1, t2, *target);
635
+
636
+ _matrix.invert();
637
+ // Apply the transform to each target point
638
+ for (int i = 0; i < target->size(); ++i) {
639
+ apply_matrix4(target->at(i), _matrix);
640
+ }
641
+ }
642
+
643
+ return intersecting;
644
+ }
645
+
646
+ // Helper function to calculate the area of a polygon
647
+ float polygon_area(const std::vector<uv_float3> &polygon) {
648
+ if (polygon.size() < 3)
649
+ return 0.0f; // Not a polygon
650
+
651
+ uv_float3 normal = {0.0f, 0.0f, 0.0f}; // Initialize normal vector
652
+
653
+ // Calculate the cross product of edges around the polygon
654
+ for (size_t i = 0; i < polygon.size(); ++i) {
655
+ uv_float3 p1 = polygon[i];
656
+ uv_float3 p2 = polygon[(i + 1) % polygon.size()];
657
+
658
+ normal = normal + cross(p1, p2); // Accumulate the normal vector
659
+ }
660
+
661
+ float area =
662
+ magnitude(normal) / 2.0f; // Area is half the magnitude of the normal
663
+ return area;
664
+ }
665
+
666
+ bool triangle_triangle_intersection(uv_float2 p1, uv_float2 q1, uv_float2 r1,
667
+ uv_float2 p2, uv_float2 q2, uv_float2 r2) {
668
+ Triangle t1(p1, q1, r1);
669
+ Triangle t2(p2, q2, r2);
670
+
671
+ if (isTriDegenerated(t1) || isTriDegenerated(t2)) {
672
+ // std::cerr << "Degenerated triangles provided, skipping." << std::endl;
673
+ return false;
674
+ }
675
+
676
+ int o1a = orient3D(t2.a, t2.b, t2.c, t1.a);
677
+ int o1b = orient3D(t2.a, t2.b, t2.c, t1.b);
678
+ int o1c = orient3D(t2.a, t2.b, t2.c, t1.c);
679
+
680
+ std::vector<uv_float3> intersections;
681
+ bool intersects;
682
+
683
+ if (o1a == o1b && o1a == o1c) // [[likely]]
684
+ {
685
+ intersects = o1a == 0 && coplanarIntersect(t1, t2, &intersections);
686
+ } else // [[unlikely]]
687
+ {
688
+ intersects = crossIntersect(t1, t2, o1a, o1b, o1c, &intersections);
689
+ }
690
+
691
+ if (intersects) {
692
+ float area = polygon_area(intersections);
693
+
694
+ // std::cout << "Intersection area: " << area << std::endl;
695
+ if (area < 1e-10f || std::isfinite(area) == false) {
696
+ // std::cout<<"Invalid area: " << area << std::endl;
697
+ return false; // Ignore intersection if the area is too small
698
+ }
699
+ }
700
+
701
+ return intersects;
702
+ }
uv_unwrapper/uv_unwrapper/csrc/intersect.h ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include "common.h"
4
+ #include <vector>
5
+
6
// Returns true when the 2D triangle (v0, v1, v2) overlaps the axis-aligned
// box whose minimum corner is aabb_min and whose maximum corner is aabb_max.
bool triangle_aabb_intersection(const uv_float2 &aabb_min,
                                const uv_float2 &aabb_max, const uv_float2 &v0,
                                const uv_float2 &v1, const uv_float2 &v2);
// Returns true when the 2D triangles (p1, q1, r1) and (p2, q2, r2) overlap.
// Degenerate triangles and overlaps with a negligible or non-finite area are
// reported as not intersecting.
bool triangle_triangle_intersection(uv_float2 p1, uv_float2 q1, uv_float2 r1,
                                    uv_float2 p2, uv_float2 q2, uv_float2 r2);
uv_unwrapper/uv_unwrapper/csrc/unwrapper.cpp ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "bvh.h"
2
+ #include <ATen/ATen.h>
3
+ #include <ATen/Context.h>
4
+ #include <chrono>
5
+ #include <cmath>
6
+ #include <cstring>
7
+ #include <omp.h>
8
+ #include <set>
9
+ #include <torch/extension.h>
10
+ #include <vector>
11
+
12
+ // #define TIMING
13
+
14
+ #if defined(_MSC_VER)
15
+ #include <BaseTsd.h>
16
+ typedef SSIZE_T ssize_t;
17
+ #endif
18
+
19
+ namespace UVUnwrapper {
20
+ void create_bvhs(BVH *bvhs, Triangle *triangles,
21
+ std::vector<std::set<int>> &triangle_per_face, int num_faces,
22
+ int start, int end) {
23
+ #pragma omp parallel for
24
+ for (int i = start; i < end; i++) {
25
+ int num_triangles = triangle_per_face[i].size();
26
+ Triangle *triangles_per_face = new Triangle[num_triangles];
27
+ int *indices = new int[num_triangles];
28
+ int j = 0;
29
+ for (int idx : triangle_per_face[i]) {
30
+ triangles_per_face[j] = triangles[idx];
31
+ indices[j++] = idx;
32
+ }
33
+ // Each thread writes to it's own memory space
34
+ // First check if the number of triangles is 0
35
+ if (num_triangles == 0) {
36
+ bvhs[i - start] = std::move(BVH()); // Default constructor
37
+ } else {
38
+ bvhs[i - start] = std::move(
39
+ BVH(triangles_per_face, indices,
40
+ num_triangles)); // BVH now handles memory of triangles_per_face
41
+ }
42
+ delete[] triangles_per_face;
43
+ }
44
+ }
45
+
46
+ void perform_intersection_check(BVH *bvhs, int num_bvhs, Triangle *triangles,
47
+ uv_float3 *vertex_tri_centroids,
48
+ int64_t *assign_indices_ptr,
49
+ ssize_t num_indices, int offset,
50
+ std::vector<std::set<int>> &triangle_per_face) {
51
+ std::vector<std::pair<int, int>>
52
+ unique_intersections; // Store unique intersections as pairs of triangle
53
+ // indices
54
+
55
+ // Step 1: Detect intersections in parallel
56
+ #pragma omp parallel for
57
+ for (int i = 0; i < num_indices; i++) {
58
+ if (assign_indices_ptr[i] < offset) {
59
+ continue;
60
+ }
61
+
62
+ Triangle cur_tri = triangles[i];
63
+ auto &cur_bvh = bvhs[assign_indices_ptr[i] - offset];
64
+
65
+ if (cur_bvh.bvhNode == nullptr) {
66
+ continue;
67
+ }
68
+
69
+ std::vector<int> intersections = cur_bvh.Intersect(cur_tri);
70
+
71
+ if (!intersections.empty()) {
72
+
73
+ #pragma omp critical
74
+ {
75
+ for (int intersect : intersections) {
76
+ if (i != intersect) {
77
+ // Ensure we only store unique pairs (A, B) where A < B to avoid
78
+ // duplication
79
+ if (i < intersect) {
80
+ unique_intersections.push_back(std::make_pair(i, intersect));
81
+ } else {
82
+ unique_intersections.push_back(std::make_pair(intersect, i));
83
+ }
84
+ }
85
+ }
86
+ }
87
+ }
88
+ }
89
+
90
+ // Step 2: Process unique intersections
91
+ for (int idx = 0; idx < unique_intersections.size(); idx++) {
92
+ int first = unique_intersections[idx].first;
93
+ int second = unique_intersections[idx].second;
94
+
95
+ int i_idx = assign_indices_ptr[first];
96
+
97
+ int norm_idx = i_idx % 6;
98
+ int axis = (norm_idx < 2) ? 0 : (norm_idx < 4) ? 1 : 2;
99
+ bool use_max = (i_idx % 2) == 1;
100
+
101
+ float pos_a = vertex_tri_centroids[first][axis];
102
+ float pos_b = vertex_tri_centroids[second][axis];
103
+ // Sort the intersections based on vertex_tri_centroids along the specified
104
+ // axis
105
+ if (use_max) {
106
+ if (pos_a < pos_b) {
107
+ std::swap(first, second);
108
+ }
109
+ } else {
110
+ if (pos_a > pos_b) {
111
+ std::swap(first, second);
112
+ }
113
+ }
114
+
115
+ // Update the unique intersections
116
+ unique_intersections[idx].first = first;
117
+ unique_intersections[idx].second = second;
118
+ }
119
+
120
+ // Now only get the second intersections from the pair and put them in a set
121
+ // The second intersection should always be the occluded triangle
122
+ std::set<int> second_intersections;
123
+ for (int idx = 0; idx < (int)unique_intersections.size(); idx++) {
124
+ int second = unique_intersections[idx].second;
125
+ second_intersections.insert(second);
126
+ }
127
+
128
+ for (int int_idx : second_intersections) {
129
+ // Move the second (occluded) triangle by 6
130
+ int intersect_idx = assign_indices_ptr[int_idx];
131
+ int new_index = intersect_idx + 6;
132
+ new_index = std::clamp(new_index, 0, 12);
133
+
134
+ assign_indices_ptr[int_idx] = new_index;
135
+ triangle_per_face[intersect_idx].erase(int_idx);
136
+ triangle_per_face[new_index].insert(int_idx);
137
+ }
138
+ }
139
+
140
+ torch::Tensor assign_faces_uv_to_atlas_index(torch::Tensor vertices,
141
+ torch::Tensor indices,
142
+ torch::Tensor face_uv,
143
+ torch::Tensor face_index) {
144
+ // Get the number of faces
145
+ int num_faces = indices.size(0);
146
+ torch::Tensor assign_indices =
147
+ torch::empty(
148
+ {
149
+ num_faces,
150
+ },
151
+ torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU))
152
+ .contiguous();
153
+
154
+ auto vert_accessor = vertices.accessor<float, 2>();
155
+ auto indices_accessor = indices.accessor<int64_t, 2>();
156
+ auto face_uv_accessor = face_uv.accessor<float, 2>();
157
+
158
+ const int64_t *face_index_ptr = face_index.contiguous().data_ptr<int64_t>();
159
+ int64_t *assign_indices_ptr = assign_indices.data_ptr<int64_t>();
160
+ // copy face_index to assign_indices
161
+ memcpy(assign_indices_ptr, face_index_ptr, num_faces * sizeof(int64_t));
162
+
163
+ #ifdef TIMING
164
+ auto start = std::chrono::high_resolution_clock::now();
165
+ #endif
166
+ uv_float3 *vertex_tri_centroids = new uv_float3[num_faces];
167
+ Triangle *triangles = new Triangle[num_faces];
168
+
169
+ // Use std::set to store triangles for each face
170
+ std::vector<std::set<int>> triangle_per_face;
171
+ triangle_per_face.resize(13);
172
+
173
+ #pragma omp parallel for
174
+ for (int i = 0; i < num_faces; i++) {
175
+ int face_idx = i * 3;
176
+ triangles[i].v0 = {face_uv_accessor[face_idx + 0][0],
177
+ face_uv_accessor[face_idx + 0][1]};
178
+ triangles[i].v1 = {face_uv_accessor[face_idx + 1][0],
179
+ face_uv_accessor[face_idx + 1][1]};
180
+ triangles[i].v2 = {face_uv_accessor[face_idx + 2][0],
181
+ face_uv_accessor[face_idx + 2][1]};
182
+ triangles[i].centroid =
183
+ triangle_centroid(triangles[i].v0, triangles[i].v1, triangles[i].v2);
184
+
185
+ uv_float3 v0 = {vert_accessor[indices_accessor[i][0]][0],
186
+ vert_accessor[indices_accessor[i][0]][1],
187
+ vert_accessor[indices_accessor[i][0]][2]};
188
+ uv_float3 v1 = {vert_accessor[indices_accessor[i][1]][0],
189
+ vert_accessor[indices_accessor[i][1]][1],
190
+ vert_accessor[indices_accessor[i][1]][2]};
191
+ uv_float3 v2 = {vert_accessor[indices_accessor[i][2]][0],
192
+ vert_accessor[indices_accessor[i][2]][1],
193
+ vert_accessor[indices_accessor[i][2]][2]};
194
+ vertex_tri_centroids[i] = triangle_centroid(v0, v1, v2);
195
+
196
+ // Assign the triangle to the face index
197
+ #pragma omp critical
198
+ { triangle_per_face[face_index_ptr[i]].insert(i); }
199
+ }
200
+
201
+ #ifdef TIMING
202
+ auto start_bvh = std::chrono::high_resolution_clock::now();
203
+ #endif
204
+
205
+ BVH *bvhs = new BVH[6];
206
+ create_bvhs(bvhs, triangles, triangle_per_face, num_faces, 0, 6);
207
+
208
+ #ifdef TIMING
209
+ auto end_bvh = std::chrono::high_resolution_clock::now();
210
+ std::chrono::duration<double> elapsed_seconds = end_bvh - start_bvh;
211
+ std::cout << "BVH build time: " << elapsed_seconds.count() << "s\n";
212
+
213
+ auto start_intersection_1 = std::chrono::high_resolution_clock::now();
214
+ #endif
215
+
216
+ perform_intersection_check(bvhs, 6, triangles, vertex_tri_centroids,
217
+ assign_indices_ptr, num_faces, 0,
218
+ triangle_per_face);
219
+
220
+ #ifdef TIMING
221
+ auto end_intersection_1 = std::chrono::high_resolution_clock::now();
222
+ elapsed_seconds = end_intersection_1 - start_intersection_1;
223
+ std::cout << "Intersection 1 time: " << elapsed_seconds.count() << "s\n";
224
+ #endif
225
+ // Create 6 new bvhs and delete the old ones
226
+ BVH *new_bvhs = new BVH[6];
227
+ create_bvhs(new_bvhs, triangles, triangle_per_face, num_faces, 6, 12);
228
+
229
+ #ifdef TIMING
230
+ auto end_bvh2 = std::chrono::high_resolution_clock::now();
231
+ elapsed_seconds = end_bvh2 - end_intersection_1;
232
+ std::cout << "BVH 2 build time: " << elapsed_seconds.count() << "s\n";
233
+ auto start_intersection_2 = std::chrono::high_resolution_clock::now();
234
+ #endif
235
+
236
+ perform_intersection_check(new_bvhs, 6, triangles, vertex_tri_centroids,
237
+ assign_indices_ptr, num_faces, 6,
238
+ triangle_per_face);
239
+
240
+ #ifdef TIMING
241
+ auto end_intersection_2 = std::chrono::high_resolution_clock::now();
242
+ elapsed_seconds = end_intersection_2 - start_intersection_2;
243
+ std::cout << "Intersection 2 time: " << elapsed_seconds.count() << "s\n";
244
+ elapsed_seconds = end_intersection_2 - start;
245
+ std::cout << "Total time: " << elapsed_seconds.count() << "s\n";
246
+ #endif
247
+
248
+ // Cleanup
249
+ delete[] vertex_tri_centroids;
250
+ delete[] triangles;
251
+ delete[] bvhs;
252
+ delete[] new_bvhs;
253
+
254
+ return assign_indices;
255
+ }
256
+
257
// Registers _C as a Python extension module.
// The module body is intentionally empty: no pybind11 bindings are defined
// here, the operator is registered through the TORCH_LIBRARY blocks below.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}

// Defines the operators
// Declares the schema of assign_faces_uv_to_atlas_index in the UVUnwrapper
// operator namespace: four tensors in, one tensor out.
TORCH_LIBRARY(UVUnwrapper, m) {
  m.def("assign_faces_uv_to_atlas_index(Tensor vertices, Tensor indices, "
        "Tensor face_uv, Tensor face_index) -> Tensor");
}

// Registers CPP implementations
// Binds the CPU dispatch key of the operator to the C++ function above.
TORCH_LIBRARY_IMPL(UVUnwrapper, CPU, m) {
  m.impl("assign_faces_uv_to_atlas_index", &assign_faces_uv_to_atlas_index);
}
270
+
271
+ } // namespace UVUnwrapper
uv_unwrapper/uv_unwrapper/unwrap.py ADDED
@@ -0,0 +1,669 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+
9
+
10
class Unwrapper(nn.Module):
    """
    Box-projection UV unwrapper.

    The mesh is aligned with its principal axes, every triangle is projected
    onto one of the six faces of the bounding cube, overlapping triangles are
    moved to additional atlas slices by the ``UVUnwrapper`` native operator,
    and all slices are finally packed into a single [0, 1] x [0, 1] atlas.
    """

    def __init__(self):
        super().__init__()

    def _box_assign_vertex_to_cube_face(
        self,
        vertex_positions: Tensor,
        vertex_normals: Tensor,
        triangle_idxs: Tensor,
        bbox: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Assigns each triangle to a cube face based on the averaged face normal
        and projects its vertices onto that face.

        Args:
            vertex_positions (Tensor, Nv 3, float): Vertex positions
            vertex_normals (Tensor, Nv 3, float): Vertex normals
            triangle_idxs (Tensor, Nf 3, int): Triangle indices
            bbox (Tensor, 2 3, float): Bounding box of the mesh (min row, max row)

        Returns:
            Tensor, Nf 3 2, float: UV coordinates in [0, 1]
            Tensor, Nf, int: Cube face indices (0: +x, 1: -x, 2: +y, 3: -y, 4: +z, 5: -z)
        """
        # Normalize the positions to [0, 1]. The extent is clamped so a mesh that
        # is flat along one axis does not produce NaNs through a 0 / 0 division.
        extent = (bbox[1:] - bbox[:1]).clamp(min=1e-8)
        v_pos_normalized = (vertex_positions - bbox[:1]) / extent
        # And to [-1, 1]
        v_pos_normalized = 2.0 * v_pos_normalized - 1.0

        # Per-triangle vertex positions and normals: Nf 3 3
        tri_stack = torch.stack(
            [v_pos_normalized[triangle_idxs[:, k]] for k in range(3)], dim=1
        )
        tri_stack_nrm = torch.stack(
            [vertex_normals[triangle_idxs[:, k]] for k in range(3)], dim=1
        )

        # Just average the normals per face
        face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1)

        # The cube face is the canonical direction the face normal agrees with most
        axis = torch.tensor(
            [
                [1, 0, 0],  # 0
                [-1, 0, 0],  # 1
                [0, 1, 0],  # 2
                [0, -1, 0],  # 3
                [0, 0, 1],  # 4
                [0, 0, -1],  # 5
            ],
            device=face_normal.device,
            dtype=face_normal.dtype,
        )
        face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1)
        index = face_normal_axis.argmax(-1)

        abs_pos = tri_stack.abs()
        x, y, z = (tri_stack[..., k : k + 1] for k in range(3))

        max_axis = torch.ones_like(abs_pos[..., 0])
        uc = torch.zeros_like(x)
        vc = torch.zeros_like(x)

        # (cube faces, dominant axis, u source, v source). Both x faces share one
        # projection, as do both y faces; only the two z faces differ in v.
        projections = (
            ((0, 1), 0, y, -z),
            ((2, 3), 1, x, -z),
            ((4,), 2, x, y),
            ((5,), 2, x, -y),
        )
        for faces, dominant, u_src, v_src in projections:
            mask = index == faces[0]
            for face in faces[1:]:
                mask = mask | (index == face)
            max_axis[mask] = abs_pos[mask][..., dominant]
            uc[mask] = u_src[mask]
            vc[mask] = v_src[mask]

        # UC from [-1, 1] to [0, 1]
        max_dim_div = max_axis.max(dim=0, keepdim=True).values.clamp(min=1e-8)
        uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
        vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)

        uv = torch.stack([uc, vc], dim=-1)

        return uv, index

    def _assign_faces_uv_to_atlas_index(
        self,
        vertex_positions: Tensor,
        triangle_idxs: Tensor,
        face_uv: Tensor,
        face_index: Tensor,
    ) -> Tensor:
        """
        Assigns the face UV to the atlas index

        The work is done by the ``UVUnwrapper.assign_faces_uv_to_atlas_index``
        operator, for which only a CPU implementation is registered. All inputs
        are therefore moved to the CPU and the result is moved back to the
        device of ``vertex_positions``.

        Args:
            vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
            triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
            face_uv (Float[Tensor, "Nf 3 2"]): Face UV coordinates
            face_index (Integer[Tensor, "Nf"]): Face indices

        Returns:
            Integer[Tensor, "Nf"]: Atlas index
        """
        return torch.ops.UVUnwrapper.assign_faces_uv_to_atlas_index(
            vertex_positions.cpu(),
            triangle_idxs.cpu(),
            face_uv.reshape(-1, 2).cpu(),
            face_index.cpu(),
        ).to(vertex_positions.device)

    def _find_slice_offset_and_scale(
        self, index: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Find the slice offset and scale

        Atlas layout: indices 0-5 form a 3x2 grid of 1/3-sized cells, indices
        6-11 form a half-scale 3x2 grid in the lowest third, and indices >= 12
        share one block starting at x = 0.5 in the lowest third.

        Args:
            index (Integer[Tensor, "Nf"]): Atlas index

        Returns:
            Float[Tensor, "Nf"]: Offset x
            Float[Tensor, "Nf"]: Offset y
            Float[Tensor, "Nf"]: Division x
            Float[Tensor, "Nf"]: Division y
        """
        # 6 due to the 6 cube faces
        off = 1 / 3
        dupl_off = 1 / 6
        grid_cols = (0, 1, 2, 0, 1, 2)
        grid_rows = (0, 0, 0, 1, 1, 1)

        offset_x = torch.zeros_like(index, dtype=torch.float32)
        offset_y = torch.zeros_like(index, dtype=torch.float32)

        # Only visit indices that actually occur. This also makes an empty index
        # tensor a no-op instead of failing on ``index.max()``.
        for i in torch.unique(index[index >= 0]).tolist():
            col, row = grid_cols[i % 6], grid_rows[i % 6]
            level = i // 6
            if level == 0:
                # Initial coordinates - just a 3x2 grid
                cur_x, cur_y = off * col, off * row
            else:
                # Smaller 3x2 grid in the lowest row plus eventual shift to the
                # right for the second overlap
                cur_x = dupl_off * col + min(level - 1, 1) * 0.5
                cur_y = dupl_off * row + off * 2
            mask = index == i
            offset_x[mask] = cur_x
            offset_y[mask] = cur_y

        div_x = torch.full_like(index, 6 // 2, dtype=torch.float32)
        # All overlap elements are saved in half scale
        div_x[index >= 6] = 6
        div_y = div_x.clone()  # Same for y
        # Except for the random overlaps
        div_x[index >= 12] = 2
        # But the random overlaps are saved in a large block in the lower thirds
        div_y[index >= 12] = 3

        return offset_x, offset_y, div_x, div_y

    def _calculate_tangents(
        self,
        vertex_positions: Tensor,
        vertex_normals: Tensor,
        triangle_idxs: Tensor,
        face_uv: Tensor,
    ) -> Tensor:
        """
        Calculate the per-vertex tangents from the triangle UV parameterization

        Args:
            vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
            vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals
            triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
            face_uv (Float[Tensor, "Nf 3 2"]): Face UV coordinates

        Returns:
            Float[Tensor, "Nv 3"]: Unit tangents, orthogonalized against the
                vertex normals. Vertices not used by any triangle get a zero
                tangent.
        """
        pos = [vertex_positions[triangle_idxs[:, i]] for i in range(3)]
        tex = face_uv.unbind(1)

        # scatter_add_ needs dense accumulators, so always allocate contiguously
        tangents = torch.zeros_like(
            vertex_normals, memory_format=torch.contiguous_format
        )
        tansum = torch.zeros_like(tangents)

        # Compute tangent space for each triangle
        duv1 = tex[1] - tex[0]
        duv2 = tex[2] - tex[0]
        dpos1 = pos[1] - pos[0]
        dpos2 = pos[2] - pos[0]

        tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]

        denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]

        # Avoid division by zero for degenerated texture coordinates. The sign of
        # the determinant has to be kept: clamping a negative determinant to a
        # positive epsilon would flip the tangent of mirrored UV triangles and
        # blow up its magnitude.
        denom_safe = torch.where(
            denom >= 0, denom.clamp(min=1e-6), denom.clamp(max=-1e-6)
        )
        tang = tng_nom / denom_safe

        # Update all 3 vertices
        ones = torch.ones_like(tang)
        for i in range(3):
            idx = triangle_idxs[:, i][:, None].repeat(1, 3)
            tangents.scatter_add_(0, idx, tang)  # tangents[n_i] = tangents[n_i] + tang
            tansum.scatter_add_(0, idx, ones)  # tansum[n_i] = tansum[n_i] + 1
        # Also normalize it. Here we do not normalize the individual triangles first
        # so larger area triangles influence the tangent space more.
        # The count is clamped so unreferenced vertices yield 0 instead of NaN.
        tangents = tangents / tansum.clamp(min=1)

        # Normalize and make sure tangent is perpendicular to normal
        tangents = F.normalize(tangents, dim=1)
        tangents = F.normalize(
            tangents
            - (tangents * vertex_normals).sum(-1, keepdim=True) * vertex_normals
        )

        return tangents

    def _rotate_uv_slices_consistent_space(
        self,
        vertex_positions: Tensor,
        vertex_normals: Tensor,
        triangle_idxs: Tensor,
        uv: Tensor,
        index: Tensor,
    ) -> Tensor:
        """
        Rotate the UV slices so they are in a consistent space

        Each of the six slices is rotated by the angle between its mean actual
        tangent and its mean expected tangent and is then rescaled to [0, 1].
        The input ``uv`` tensor is left untouched.

        Args:
            vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
            vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals
            triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
            uv (Float[Tensor, "Nf 3 2"]): UV coordinates
            index (Integer[Tensor, "Nf"]): Atlas index

        Returns:
            Float[Tensor, "Nf 3 2"]: Rotated UV coordinates
        """
        tangents = self._calculate_tangents(
            vertex_positions, vertex_normals, triangle_idxs, uv
        )
        pos_stack = torch.stack(
            [
                -vertex_positions[..., 1],
                vertex_positions[..., 0],
                torch.zeros_like(vertex_positions[..., 0]),
            ],
            dim=-1,
        )
        expected_tangents = F.normalize(
            torch.linalg.cross(
                vertex_normals,
                torch.linalg.cross(pos_stack, vertex_normals, dim=-1),
                dim=-1,
            ),
            -1,
        )

        actual_tangents = tangents[triangle_idxs]
        expected_tangents = expected_tangents[triangle_idxs]

        # Work on a copy so the caller's tensor is not modified in place
        uv = uv.clone()

        # Now find the rotation
        index_mod = index % 6  # Shouldn't happen. Just for safety
        for i in range(6):
            mask = index_mod == i
            if not mask.any():
                continue

            actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
            expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))

            dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
            cross_product = (
                actual_mean_tangent[0] * expected_mean_tangent[1]
                - actual_mean_tangent[1] * expected_mean_tangent[0]
            )
            angle = torch.atan2(cross_product, dot_product)

            # Build the 2D rotation matrix with stack so it stays on the same
            # device and no host round-trip is needed
            c, s = torch.cos(angle), torch.sin(angle)
            rot_matrix = torch.stack(
                [torch.stack([c, -s]), torch.stack([s, c])]
            ).to(uv)

            # Center the uv coordinate to be in the range of -1 to 1 and rotate it
            rotated = torch.einsum("ij,nfj->nfi", rot_matrix, uv[mask] * 2 - 1)

            # Rescale the slice to be within the 0-1 range. The range is clamped
            # so a slice collapsed to a single point does not produce NaNs.
            lo, hi = rotated.min(), rotated.max()
            uv[mask] = (rotated - lo) / (hi - lo).clamp(min=1e-8)

        return uv

    def _handle_slice_uvs(
        self,
        uv: Tensor,
        index: Tensor,
        island_padding: float,
        max_index: int = 6 * 2,
    ) -> Tensor:
        """
        Handle the slice UVs

        Stretches every overlap slice (atlas index 6 .. max_index - 1) to fill
        its atlas patch and applies the island padding to all UVs. The input
        ``uv`` tensor is left untouched.

        Args:
            uv (Float[Tensor, "Nf 3 2"]): UV coordinates
            index (Integer[Tensor, "Nf"]): Atlas index
            island_padding (float): Island padding
            max_index (int): Maximum index

        Returns:
            Float[Tensor, "Nf 3 2"]: Updated UV coordinates
        """
        # unbind returns views, so clone first to keep the caller's tensor intact
        uc, vc = uv.clone().unbind(-1)

        # Normalize the overlap slices to always fully fill the atlas patch
        for i in range(6, max_index):
            fi = index == i
            if not fi.any():
                continue
            # Scale the slice but only up to a factor of 2
            # This keeps the texture resolution with the first slice in line
            # (Half space in UV)
            for coord in (uc, vc):
                values = coord[fi]
                lo = values.min()
                coord[fi] = (values - lo) / (values.max() - lo).clip(0.5)

        uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
        vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)

        return torch.stack([uc_padded, vc_padded], dim=-1)

    def _handle_remaining_uvs(
        self,
        uv: Tensor,
        index: Tensor,
        island_padding: float,
    ) -> Tensor:
        """
        Handle the remaining UVs (The ones that are not slices)

        Every triangle with atlas index >= 12 gets its own cell in a regular
        grid. If there is no such triangle ``uv`` is returned as is, otherwise
        a modified copy is returned.

        Args:
            uv (Float[Tensor, "Nf 3 2"]): UV coordinates
            index (Integer[Tensor, "Nf"]): Atlas index
            island_padding (float): Island padding

        Returns:
            Float[Tensor, "Nf 3 2"]: Updated UV coordinates
        """
        # Get all remaining elements
        remaining_filter = index >= 6 * 2
        squares_left = int(remaining_filter.sum().item())

        if squares_left == 0:
            return uv

        uv = uv.clone()
        uc, vc = uv[remaining_filter].unbind(-1)

        # Our remaining triangles are distributed in a rectangle
        # The rectangle takes 0.5 of the entire uv space in width and 1/3 in height
        ratio = 0.5 * (1 / 3)  # = 1/6

        mult = math.sqrt(squares_left / ratio)
        num_square_width = int(math.ceil(0.5 * mult))
        num_square_height = int(math.ceil(squares_left / num_square_width))

        width = 1 / num_square_width
        height = 1 / num_square_height

        # The idea is again to keep the texture resolution consistent with the first
        # slice. This only occupies half the region in the texture chart but the
        # scaling on the squares assumes full coverage.
        clip_val = min(width, height) * 1.5
        # Now normalize the UVs per triangle, taking into account the maximum scaling
        uc_min = uc.amin(dim=1, keepdim=True)
        vc_min = vc.amin(dim=1, keepdim=True)
        uc = (uc - uc_min) / (uc.amax(dim=1, keepdim=True) - uc_min).clip(clip_val)
        vc = (vc - vc_min) / (vc.amax(dim=1, keepdim=True) - vc_min).clip(clip_val)
        # Add a small padding
        uc = (
            uc * (1 - island_padding * num_square_width * 0.5)
            + island_padding * num_square_width * 0.25
        ).clip(0, 1)
        vc = (
            vc * (1 - island_padding * num_square_height * 0.5)
            + island_padding * num_square_height * 0.25
        ).clip(0, 1)

        uc = uc * width
        vc = vc * height

        # And calculate offsets for each element
        idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32)
        x_idx = idx % num_square_width
        y_idx = idx // num_square_width
        # And move each triangle to its own spot
        uc = uc + x_idx[:, None] * width
        vc = vc + y_idx[:, None] * height

        uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
        vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)

        uv[remaining_filter] = torch.stack([uc, vc], dim=-1)

        return uv

    def _distribute_individual_uvs_in_atlas(
        self,
        face_uv: Tensor,
        assigned_faces: Tensor,
        offset_x: Tensor,
        offset_y: Tensor,
        div_x: Tensor,
        div_y: Tensor,
        island_padding: float,
    ) -> Tensor:
        """
        Distribute the individual UVs in the atlas

        Args:
            face_uv (Float[Tensor, "Nf 3 2"]): Face UV coordinates
            assigned_faces (Integer[Tensor, "Nf"]): Assigned faces
            offset_x (Float[Tensor, "Nf"]): Offset x
            offset_y (Float[Tensor, "Nf"]): Offset y
            div_x (Float[Tensor, "Nf"]): Division x
            div_y (Float[Tensor, "Nf"]): Division y
            island_padding (float): Island padding

        Returns:
            Float[Tensor, "Nf*3 2"]: Atlas UV coordinates, flattened over the
                triangle corners
        """
        # Place the slice first
        placed_uv = self._handle_slice_uvs(face_uv, assigned_faces, island_padding)
        # Then handle the remaining overlap elements
        placed_uv = self._handle_remaining_uvs(
            placed_uv, assigned_faces, island_padding
        )

        # Shrink every triangle into its atlas cell and move it to the cell origin
        uc, vc = placed_uv.unbind(-1)
        uc = uc / div_x[:, None] + offset_x[:, None]
        vc = vc / div_y[:, None] + offset_y[:, None]

        return torch.stack([uc, vc], dim=-1).reshape(-1, 2)

    def _get_unique_face_uv(
        self,
        uv: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get the unique face UV

        Args:
            uv (Float[Tensor, "Nf*3 2"]): Flattened per-corner UV coordinates

        Returns:
            Float[Tensor, "Utex 2"]: Unique UV coordinates
            Integer[Tensor, "Nf 3"]: For every triangle corner the row of its UV
                in the unique UV tensor
        """
        unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0)
        # And add the face to uv index mapping
        vtex_idx = unique_idx.view(-1, 3)

        return unique_uv, vtex_idx

    def _align_mesh_with_main_axis(
        self, vertex_positions: Tensor, vertex_normals: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Align the mesh with the main axis

        Args:
            vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
            vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals

        Returns:
            Float[Tensor, "Nv 3"]: Rotated vertex positions
            Float[Tensor, "Nv 3"]: Rotated vertex normals

        Raises:
            ValueError: If the principal axes cannot be mapped to three distinct
                canonical axes
        """
        # Use pca to find the 2 main axis (third is derived by cross product).
        # pca_lowrank is randomized, so it is seeded to be repeatable. The seeding
        # happens inside fork_rng so the caller's CPU (and CUDA, if used) random
        # state is restored afterwards instead of being reset on every call.
        # NOTE: the RNG state of other backends (e.g. MPS) is not restored.
        rng_devices = [vertex_positions.device] if vertex_positions.is_cuda else []
        with torch.random.fork_rng(devices=rng_devices):
            torch.manual_seed(0)
            _, _, v = torch.pca_lowrank(vertex_positions, q=2)
        main_axis, seconday_axis = v[:, 0], v[:, 1]

        main_axis = F.normalize(main_axis, eps=1e-6, dim=-1)  # 3,
        # Orthogonalize the second axis
        seconday_axis = F.normalize(
            seconday_axis
            - (seconday_axis * main_axis).sum(-1, keepdim=True) * main_axis,
            eps=1e-6,
            dim=-1,
        )  # 3,
        # Create perpendicular third axis
        third_axis = F.normalize(
            torch.cross(main_axis, seconday_axis, dim=-1), dim=-1, eps=1e-6
        )  # 3,

        # Check to which canonical axis each aligns (main, secondary, third)
        canonical = [
            main_axis.abs().argmax().item(),
            seconday_axis.abs().argmax().item(),
            third_axis.abs().argmax().item(),
        ]

        # Now sort the axes based on the argmax so they align with the canonical
        # axes. If two axes have the same argmax move one of them to the missing
        # canonical axis: first the third axis, as it had the smallest
        # contribution to the overall shape, then the secondary axis.
        for slot in (2, 1):
            if len(set(canonical)) == 3:
                break
            canonical[slot] = min({0, 1, 2} - set(canonical))

        if len(set(canonical)) != 3:
            raise ValueError("Could not find 3 unique axis")

        axes = [None] * 3
        for target, axis_vec in zip(canonical, (main_axis, seconday_axis, third_axis)):
            axes[target] = axis_vec
        # Create rotation matrix from the individual axes (one axis per row)
        rot_mat = torch.stack(axes, dim=0)

        # Now rotate the vertex positions and vertex normals so the mesh aligns
        # with the main axis
        vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
        vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)

        return vertex_positions, vertex_normals

    def forward(
        self,
        vertex_positions: Tensor,
        vertex_normals: Tensor,
        triangle_idxs: Tensor,
        island_padding: float,
    ) -> Tuple[Tensor, Tensor]:
        """
        Unwrap the mesh

        Args:
            vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
            vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals
            triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
            island_padding (float): Island padding

        Returns:
            Float[Tensor, "Utex 2"]: Unique UV coordinates
            Integer[Tensor, "Nf 3"]: Per-triangle indices into the unique UVs
        """
        vertex_positions, vertex_normals = self._align_mesh_with_main_axis(
            vertex_positions, vertex_normals
        )
        bbox = torch.stack(
            [vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values],
            dim=0,
        )  # 2, 3

        face_uv, face_index = self._box_assign_vertex_to_cube_face(
            vertex_positions, vertex_normals, triangle_idxs, bbox
        )

        face_uv = self._rotate_uv_slices_consistent_space(
            vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
        )

        assigned_atlas_index = self._assign_faces_uv_to_atlas_index(
            vertex_positions, triangle_idxs, face_uv, face_index
        )

        offset_x, offset_y, div_x, div_y = self._find_slice_offset_and_scale(
            assigned_atlas_index
        )

        placed_uv = self._distribute_individual_uvs_in_atlas(
            face_uv,
            assigned_atlas_index,
            offset_x,
            offset_y,
            div_x,
            div_y,
            island_padding,
        )

        return self._get_unique_face_uv(placed_uv)