udbhav commited on
Commit
8c96afb
·
1 Parent(s): ce31518

Add CPU fallback for segment_csr to run on HF Space

Browse files
Files changed (1) hide show
  1. models/ansysLPFMs.py +1 -42
models/ansysLPFMs.py CHANGED
@@ -8,48 +8,7 @@ import einops
8
  import torch
9
  from torch import nn
10
  from torch_geometric.nn.pool import radius_graph
11
- # from torch_scatter import segment_csr
12
- # Try to use torch_scatter if available (cluster / GPU env),
13
- # otherwise fall back to a pure PyTorch CPU implementation.
14
- try:
15
- from torch_scatter import segment_csr as _segment_csr
16
- HAS_TORCH_SCATTER = True
17
- except Exception:
18
- HAS_TORCH_SCATTER = False
19
-
20
- def _segment_csr(src: torch.Tensor, indptr: torch.Tensor, reduce: str = "sum") -> torch.Tensor:
21
- """
22
- CPU-only fallback for torch_scatter.segment_csr.
23
-
24
- Parameters
25
- ----------
26
- src : Tensor
27
- Shape [N, ...], values to be reduced.
28
- indptr : Tensor
29
- Shape [num_segments + 1], CSR-style pointers.
30
- reduce : str
31
- Only 'sum' is implemented here.
32
-
33
- Returns
34
- -------
35
- out : Tensor
36
- Shape [num_segments, ...], segment-wise reductions.
37
- """
38
- if reduce != "sum":
39
- raise NotImplementedError("Fallback segment_csr currently supports only reduce='sum'")
40
-
41
- # Number of segments
42
- num_segments = indptr.numel() - 1
43
- out_shape = (num_segments,) + src.shape[1:]
44
- out = src.new_zeros(out_shape)
45
-
46
- # Simple loop over segments (fine for CPU inference in HF Space)
47
- for i in range(num_segments):
48
- start = int(indptr[i].item())
49
- end = int(indptr[i + 1].item())
50
- if start < end:
51
- out[i] = src[start:end].sum(dim=0)
52
- return out
53
  import torch.nn.functional as F
54
 
55
 
 
8
  import torch
9
  from torch import nn
10
  from torch_geometric.nn.pool import radius_graph
11
+ from torch_scatter import segment_csr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  import torch.nn.functional as F
13
 
14