Spaces:
Sleeping
Sleeping
udbhav
commited on
Commit
·
8c96afb
1
Parent(s):
ce31518
Add CPU fallback for segment_csr to run on HF Space
Browse files- models/ansysLPFMs.py +1 -42
models/ansysLPFMs.py
CHANGED
|
@@ -8,48 +8,7 @@ import einops
|
|
| 8 |
import torch
|
| 9 |
from torch import nn
|
| 10 |
from torch_geometric.nn.pool import radius_graph
|
| 11 |
-
|
| 12 |
-
# Try to use torch_scatter if available (cluster / GPU env),
|
| 13 |
-
# otherwise fall back to a pure PyTorch CPU implementation.
|
| 14 |
-
try:
|
| 15 |
-
from torch_scatter import segment_csr as _segment_csr
|
| 16 |
-
HAS_TORCH_SCATTER = True
|
| 17 |
-
except Exception:
|
| 18 |
-
HAS_TORCH_SCATTER = False
|
| 19 |
-
|
| 20 |
-
def _segment_csr(src: torch.Tensor, indptr: torch.Tensor, reduce: str = "sum") -> torch.Tensor:
|
| 21 |
-
"""
|
| 22 |
-
CPU-only fallback for torch_scatter.segment_csr.
|
| 23 |
-
|
| 24 |
-
Parameters
|
| 25 |
-
----------
|
| 26 |
-
src : Tensor
|
| 27 |
-
Shape [N, ...], values to be reduced.
|
| 28 |
-
indptr : Tensor
|
| 29 |
-
Shape [num_segments + 1], CSR-style pointers.
|
| 30 |
-
reduce : str
|
| 31 |
-
Only 'sum' is implemented here.
|
| 32 |
-
|
| 33 |
-
Returns
|
| 34 |
-
-------
|
| 35 |
-
out : Tensor
|
| 36 |
-
Shape [num_segments, ...], segment-wise reductions.
|
| 37 |
-
"""
|
| 38 |
-
if reduce != "sum":
|
| 39 |
-
raise NotImplementedError("Fallback segment_csr currently supports only reduce='sum'")
|
| 40 |
-
|
| 41 |
-
# Number of segments
|
| 42 |
-
num_segments = indptr.numel() - 1
|
| 43 |
-
out_shape = (num_segments,) + src.shape[1:]
|
| 44 |
-
out = src.new_zeros(out_shape)
|
| 45 |
-
|
| 46 |
-
# Simple loop over segments (fine for CPU inference in HF Space)
|
| 47 |
-
for i in range(num_segments):
|
| 48 |
-
start = int(indptr[i].item())
|
| 49 |
-
end = int(indptr[i + 1].item())
|
| 50 |
-
if start < end:
|
| 51 |
-
out[i] = src[start:end].sum(dim=0)
|
| 52 |
-
return out
|
| 53 |
import torch.nn.functional as F
|
| 54 |
|
| 55 |
|
|
|
|
| 8 |
import torch
|
| 9 |
from torch import nn
|
| 10 |
from torch_geometric.nn.pool import radius_graph
|
| 11 |
+
from torch_scatter import segment_csr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
import torch.nn.functional as F
|
| 13 |
|
| 14 |
|