Spaces:
Running
Running
import asyncio
import contextlib
import logging
import os

import httpx
import websockets
from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import PlainTextResponse, StreamingResponse
# Base URL of the Space being proxied (overridable via the UPSTREAM env var).
# Any trailing slash is stripped so "{UPSTREAM}/{path}" never doubles it.
UPSTREAM = os.environ.get("UPSTREAM", "https://cjerzak-policyview.hf.space").rstrip("/")

# Bearer token attached to every upstream request. A missing or empty value
# aborts start-up instead of silently proxying unauthenticated traffic.
TOKEN = os.environ.get("HF_TOKEN")
if not TOKEN:
    raise RuntimeError("Set HF_TOKEN as a Secret in the Space settings.")

app = FastAPI()
def healthz():
    """Liveness probe: always answer with the plain-text body ``ok``.

    NOTE(review): no route decorator is visible for this function; presumably
    it is registered on ``app`` (e.g. ``GET /healthz``) — confirm.
    """
    body = "ok"
    return PlainTextResponse(body)
async def stream_bytes(resp):
    """Yield the body of a streamed httpx response chunk by chunk.

    httpx's ``aiter_bytes()`` yields the decoded body (Content-Encoding undone).
    The response is closed in ``finally``, so its connection is released in
    three cases: the body is fully consumed, the consumer stops early, or
    iteration raises.

    Args:
        resp: An httpx response opened with ``stream=True``. Any object with
            an async ``aiter_bytes()`` iterator and an async ``aclose()`` works.

    Yields:
        bytes: Successive body chunks.
    """
    try:
        async for chunk in resp.aiter_bytes():
            yield chunk
    finally:
        # Without this the pooled connection stays checked out until GC.
        await resp.aclose()
async def proxy(request: Request, path: str):
    """Forward an HTTP request to ``{UPSTREAM}/{path}`` and stream the reply back.

    Request headers:
    - The caller's headers are forwarded, except ``host``, ``content-length``,
      ``authorization``, ``accept-encoding`` and hop-by-hop headers.
    - ``Authorization`` is replaced with the Space token.
    - ``x-forwarded-host`` and ``x-forwarded-proto`` describe the original request.

    Redirects are not followed. A ``Location`` that points at ``UPSTREAM`` is
    rewritten to a path on this proxy.

    NOTE(review): no route decorator is visible for this function; presumably
    it is registered as a catch-all route on ``app`` — confirm.

    Args:
        request: The incoming request.
        path: Path (no leading slash) to request from the upstream Space.

    Returns:
        A ``StreamingResponse`` with the upstream status, filtered headers and
        body. The upstream response and HTTP client are closed by a background
        task once streaming ends (including when the client disconnects).
        Returns a 502 plain-text response if the upstream request fails at the
        transport level.
    """
    # Headers scoped to a single connection; never relayed in either direction.
    hop_by_hop = {
        "connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
        "te", "trailers", "transfer-encoding", "upgrade",
    }

    url = f"{UPSTREAM}/{path}"
    if request.url.query:
        url += f"?{request.url.query}"

    # accept-encoding is dropped so httpx advertises only encodings it can decode.
    # The body is re-streamed decoded, so an encoding httpx can't undo would corrupt it.
    skip_in = hop_by_hop | {"host", "content-length", "authorization", "accept-encoding"}
    headers = {k: v for k, v in request.headers.items() if k.lower() not in skip_in}
    headers["Authorization"] = f"Bearer {TOKEN}"
    headers["x-forwarded-host"] = request.headers.get("host", "")
    headers["x-forwarded-proto"] = request.url.scheme

    body = await request.body()

    # The client must outlive this function: StreamingResponse reads the body
    # after we return, so it is closed by a background task, not `async with`.
    client = httpx.AsyncClient(follow_redirects=False, timeout=httpx.Timeout(60.0, connect=60.0))
    try:
        upstream_request = client.build_request(request.method, url, headers=headers, content=body)
        # AsyncClient.request() has no `stream` argument; send(stream=True)
        # returns once headers arrive and leaves the body unread.
        upstream = await client.send(upstream_request, stream=True)
    except httpx.HTTPError as exc:
        await client.aclose()
        return PlainTextResponse(f"Upstream request failed: {type(exc).__name__}", status_code=502)
    except BaseException:
        await client.aclose()
        raise

    # content-encoding is dropped because aiter_bytes() yields decoded bytes.
    # set-cookie is withheld from the client, as before.
    skip_out = hop_by_hop | {"content-length", "content-encoding", "set-cookie"}
    out_headers = {k: v for k, v in upstream.headers.items() if k.lower() not in skip_out}

    # Rewrite only a genuine UPSTREAM prefix (not e.g. "<UPSTREAM>.evil.com"),
    # and never emit an empty Location.
    loc = upstream.headers.get("location")
    if loc and (loc == UPSTREAM or loc.startswith((UPSTREAM + "/", UPSTREAM + "?"))):
        out_headers["location"] = loc[len(UPSTREAM):] or "/"

    cleanup = BackgroundTasks()
    cleanup.add_task(upstream.aclose)
    cleanup.add_task(client.aclose)
    return StreamingResponse(
        stream_bytes(upstream),
        status_code=upstream.status_code,
        headers=out_headers,
        background=cleanup,
    )
async def ws_proxy(ws: WebSocket, path: str):
    """Relay a client WebSocket to ``{UPSTREAM}/{path}`` in both directions.

    The client connection is accepted first. An upstream WebSocket is then
    opened with the Space token in ``Authorization`` and ``Origin: UPSTREAM``.
    Text frames are forwarded as text and binary frames as binary. As soon as
    either direction stops (client disconnect or upstream close), the other
    pump is cancelled. The client socket is then closed with 1000, or with
    1011 after an unexpected error, which is logged.

    NOTE(review): no route decorator is visible for this function; presumably
    it is registered as a catch-all WebSocket route on ``app`` — confirm.

    Args:
        ws: The incoming client WebSocket (not yet accepted).
        path: Path (no leading slash) to open on the upstream Space.
    """
    # websockets >= 13 ships a new asyncio client (the default `connect` from 14 on).
    # It takes `additional_headers`; the legacy client only knows `extra_headers`.
    try:
        from websockets.asyncio.client import connect as ws_connect
        header_kwarg = "additional_headers"
    except ImportError:
        ws_connect = websockets.connect
        header_kwarg = "extra_headers"

    log = logging.getLogger(__name__)
    await ws.accept()

    # Swap only the scheme prefix; replacing "https://" anywhere in the URL
    # would also rewrite URLs embedded in the query string.
    for http_scheme, ws_scheme in (("https://", "wss://"), ("http://", "ws://")):
        if UPSTREAM.startswith(http_scheme):
            base = ws_scheme + UPSTREAM[len(http_scheme):]
            break
    else:
        base = UPSTREAM
    target = f"{base}/{path}"
    if ws.url.query:
        target += f"?{ws.url.query}"

    async def client_to_up(ups):
        """Forward client frames upstream until the client disconnects."""
        while True:
            msg = await ws.receive()
            if msg.get("type") == "websocket.disconnect":
                return
            # Some ASGI servers send both keys with the unused one set to None,
            # so test values rather than key presence.
            data = msg.get("text")
            if data is None:
                data = msg.get("bytes")
            if data is None:
                return
            await ups.send(data)

    async def up_to_client(ups):
        """Forward upstream frames to the client until upstream closes."""
        while True:
            try:
                data = await ups.recv()
            except websockets.ConnectionClosed:
                return
            if isinstance(data, (bytes, bytearray)):
                await ws.send_bytes(data)
            else:
                await ws.send_text(data)

    close_code = 1000
    try:
        connect_kwargs = {header_kwarg: [("Authorization", f"Bearer {TOKEN}")]}
        async with ws_connect(target, origin=UPSTREAM, **connect_kwargs) as ups:
            pumps = [
                asyncio.create_task(client_to_up(ups)),
                asyncio.create_task(up_to_client(ups)),
            ]
            try:
                done, _ = await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
            finally:
                # gather() would keep waiting on the still-open direction; stop it now.
                for task in pumps:
                    task.cancel()
                await asyncio.gather(*pumps, return_exceptions=True)
            for task in done:
                task.result()  # re-raise an unexpected error from the finished pump
    except WebSocketDisconnect:
        pass
    except Exception:
        log.exception("WebSocket proxy to %s failed", target)
        close_code = 1011  # internal error
    # Best-effort: the client may already be gone, in which case close() raises.
    with contextlib.suppress(RuntimeError, OSError):
        await ws.close(code=close_code)