sharath88 commited on
Commit
ecc6931
·
1 Parent(s): 5edee94

Add semantic agreement mode to NLI analysis UI

Browse files

Introduces a semantic agreement mode alongside strict NLI mode for comparing texts. Updates backend to provide both modes, refactors score aggregation and labeling logic, and enhances frontend with mode switching and improved result display.

Files changed (3) hide show
  1. main.py +67 -42
  2. static/app.js +69 -57
  3. templates/index.html +19 -18
main.py CHANGED
@@ -1,5 +1,4 @@
1
- from typing import Dict, List
2
-
3
  from fastapi import FastAPI, Request
4
  from fastapi.responses import HTMLResponse, JSONResponse
5
  from fastapi.staticfiles import StaticFiles
@@ -7,10 +6,9 @@ from fastapi.templating import Jinja2Templates
7
  from pydantic import BaseModel
8
  from transformers import pipeline
9
 
10
- # ---------------- FastAPI setup ----------------
11
 
12
  app = FastAPI()
13
-
14
  app.mount("/static", StaticFiles(directory="static"), name="static")
15
  templates = Jinja2Templates(directory="templates")
16
 
@@ -25,7 +23,7 @@ async def home(request: Request):
25
  return templates.TemplateResponse("index.html", {"request": request})
26
 
27
 
28
- # ---------------- NLI model setup ----------------
29
 
30
  nli_pipeline = pipeline(
31
  "text-classification",
@@ -34,19 +32,17 @@ nli_pipeline = pipeline(
34
  )
35
 
36
 
37
- def get_direction_scores(premise: str, hypothesis: str) -> Dict[str, float]:
38
  """
39
- Run NLI in one direction: premise -> hypothesis.
40
- Returns normalized scores for ENTAILMENT, CONTRADICTION, NEUTRAL.
41
  """
42
  outputs = nli_pipeline(
43
  {"text": premise, "text_pair": hypothesis},
44
- return_all_scores=True,
45
  )[0]
46
 
47
  scores = {o["label"].upper(): float(o["score"]) for o in outputs}
48
 
49
- # Normalize keys to standard names
50
  return {
51
  "ENTAILMENT": scores.get("ENTAILMENT", scores.get("LABEL_2", 0.0)),
52
  "CONTRADICTION": scores.get("CONTRADICTION", scores.get("LABEL_0", 0.0)),
@@ -54,45 +50,48 @@ def get_direction_scores(premise: str, hypothesis: str) -> Dict[str, float]:
54
  }
55
 
56
 
57
- def aggregate_nli(text_a: str, text_b: str) -> Dict:
58
- """
59
- Compute NLI in both directions A->B and B->A and aggregate scores.
60
- """
61
- scores_ab = get_direction_scores(text_a, text_b)
62
- scores_ba = get_direction_scores(text_b, text_a)
63
-
64
- agg = {
65
- "ENTAILMENT": (scores_ab["ENTAILMENT"] + scores_ba["ENTAILMENT"]) / 2.0,
66
- "CONTRADICTION": (scores_ab["CONTRADICTION"] + scores_ba["CONTRADICTION"]) / 2.0,
67
- "NEUTRAL": (scores_ab["NEUTRAL"] + scores_ba["NEUTRAL"]) / 2.0,
68
  }
69
 
70
- # Decide label
 
 
71
  label = max(agg, key=agg.get)
72
  confidence = agg[label]
73
 
74
- # Simple interpretation
75
- if label == "ENTAILMENT":
76
- explanation = "The two texts are largely consistent / in agreement."
77
- elif label == "CONTRADICTION":
78
- explanation = "The two texts appear to contradict each other."
79
- else:
80
- explanation = "The relationship is mostly neutral or only partially related."
81
-
82
- return {
83
- "overall_label": label,
84
- "overall_confidence": confidence,
85
- "explanation": explanation,
86
- "aggregated_scores": agg,
87
- "directional": {
88
- "A_to_B": scores_ab,
89
- "B_to_A": scores_ba,
90
- },
91
  }
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  @app.post("/analyze")
95
- async def analyze_nli(payload: NLIRequest):
96
  text_a = payload.text_a.strip()
97
  text_b = payload.text_b.strip()
98
 
@@ -103,8 +102,34 @@ async def analyze_nli(payload: NLIRequest):
103
  )
104
 
105
  try:
106
- result = aggregate_nli(text_a, text_b)
107
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  except Exception as e:
109
  return JSONResponse(
110
  {"error": f"Error running NLI model: {e}"},
 
1
+ from typing import Dict
 
2
  from fastapi import FastAPI, Request
3
  from fastapi.responses import HTMLResponse, JSONResponse
4
  from fastapi.staticfiles import StaticFiles
 
6
  from pydantic import BaseModel
7
  from transformers import pipeline
8
 
9
+ # ---------------- FastAPI Setup ----------------
10
 
11
  app = FastAPI()
 
12
  app.mount("/static", StaticFiles(directory="static"), name="static")
13
  templates = Jinja2Templates(directory="templates")
14
 
 
23
  return templates.TemplateResponse("index.html", {"request": request})
24
 
25
 
26
+ # ---------------- NLI Model ----------------
27
 
28
  nli_pipeline = pipeline(
29
  "text-classification",
 
32
  )
33
 
34
 
35
+ def nli_scores(premise: str, hypothesis: str) -> Dict[str, float]:
36
  """
37
+ Compute NLI scores in one direction.
 
38
  """
39
  outputs = nli_pipeline(
40
  {"text": premise, "text_pair": hypothesis},
41
+ return_all_scores=True
42
  )[0]
43
 
44
  scores = {o["label"].upper(): float(o["score"]) for o in outputs}
45
 
 
46
  return {
47
  "ENTAILMENT": scores.get("ENTAILMENT", scores.get("LABEL_2", 0.0)),
48
  "CONTRADICTION": scores.get("CONTRADICTION", scores.get("LABEL_0", 0.0)),
 
50
  }
51
 
52
 
53
def aggregate_scores(a: Dict[str, float], b: Dict[str, float]) -> Dict[str, float]:
    """Return the per-label mean of the A->B and B->A NLI score dicts.

    Both inputs are expected to carry the three MNLI labels
    (ENTAILMENT / CONTRADICTION / NEUTRAL); the output averages them
    label by label so the result is direction-symmetric.
    """
    labels = ("ENTAILMENT", "CONTRADICTION", "NEUTRAL")
    return {label: (a[label] + b[label]) / 2 for label in labels}
60
 
61
+
62
def strict_nli_label(agg: Dict[str, float]):
    """Pick the traditional MNLI verdict from aggregated scores.

    Returns a (label, confidence, explanation) tuple where label is the
    highest-scoring of the three MNLI classes and confidence is its
    aggregated score.
    """
    explanations = {
        "ENTAILMENT": "The statements logically agree.",
        "CONTRADICTION": "The statements logically contradict each other.",
        "NEUTRAL": "The statements are neither supporting nor contradicting.",
    }

    # The winner is simply the label with the highest aggregated probability.
    best = max(agg, key=agg.get)
    return best, agg[best], explanations[best]
74
+
75
+
76
def semantic_mode_label(agg: Dict[str, float]):
    """Bucket aggregated NLI scores into a human-friendly agreement verdict.

    Returns a (label, explanation) pair. Unlike the strict MNLI argmax,
    this mode applies fixed thresholds so near-misses read as partial
    agreement or uncertainty rather than a hard class.
    """
    ent, con, neu = (agg[k] for k in ("ENTAILMENT", "CONTRADICTION", "NEUTRAL"))

    # Ordered rules: the first predicate that holds decides the verdict.
    rules = [
        (ent > 0.60 and con < 0.25,
         ("Strong Agreement", "The statements strongly support the same meaning.")),
        (ent > 0.40 and neu > 0.30,
         ("Partial Agreement", "They share overlap but are not exact entailments.")),
        (con > 0.50,
         ("Likely Contradiction", "The statements conflict in their meaning.")),
    ]
    for matched, verdict in rules:
        if matched:
            return verdict

    # Fallback when no rule fires.
    return "Mixed / Unclear", "No strong relationship detected, or model is uncertain."
91
+
92
 
93
  @app.post("/analyze")
94
+ async def analyze(payload: NLIRequest):
95
  text_a = payload.text_a.strip()
96
  text_b = payload.text_b.strip()
97
 
 
102
  )
103
 
104
  try:
105
+ # Compute directional scores
106
+ a_to_b = nli_scores(text_a, text_b)
107
+ b_to_a = nli_scores(text_b, text_a)
108
+
109
+ # Aggregate
110
+ agg = aggregate_scores(a_to_b, b_to_a)
111
+
112
+ # Strict NLI mode
113
+ strict_label, strict_conf, strict_expl = strict_nli_label(agg)
114
+
115
+ # Semantic mode
116
+ semantic_label, semantic_expl = semantic_mode_label(agg)
117
+
118
+ return {
119
+ "strict_mode": {
120
+ "label": strict_label,
121
+ "confidence": strict_conf,
122
+ "explanation": strict_expl,
123
+ "scores": agg,
124
+ "directional": {"A_to_B": a_to_b, "B_to_A": b_to_a},
125
+ },
126
+ "semantic_mode": {
127
+ "label": semantic_label,
128
+ "explanation": semantic_expl,
129
+ "scores": agg,
130
+ }
131
+ }
132
+
133
  except Exception as e:
134
  return JSONResponse(
135
  {"error": f"Error running NLI model: {e}"},
static/app.js CHANGED
@@ -1,95 +1,107 @@
1
  const textAEl = document.getElementById("textA");
2
  const textBEl = document.getElementById("textB");
3
  const analyzeBtn = document.getElementById("analyzeBtn");
 
4
  const statusEl = document.getElementById("status");
5
- const overallLabelEl = document.getElementById("overallLabel");
6
- const overallExplanationEl = document.getElementById("overallExplanation");
7
  const scoreBoxEl = document.getElementById("scoreBox");
8
 
9
- function formatPercentage(x) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  return (x * 100).toFixed(1) + "%";
11
  }
12
 
13
- async function analyzeRelationship() {
14
- const textA = textAEl.value.trim();
15
- const textB = textBEl.value.trim();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- if (!textA || !textB) {
18
- statusEl.textContent = "Please fill in both Text A and Text B.";
19
  statusEl.className = "status status-warn";
20
  return;
21
  }
22
 
23
- statusEl.textContent = "Analyzing relationship...";
24
  statusEl.className = "status status-info";
25
 
26
- overallLabelEl.textContent = "Computing...";
27
- overallLabelEl.className = "overall-label";
28
- overallExplanationEl.textContent = "";
29
- scoreBoxEl.textContent = "Waiting for scores...";
30
- scoreBoxEl.className = "score-box";
31
 
32
  try {
33
  const res = await fetch("/analyze", {
34
  method: "POST",
35
- headers: { "Content-Type": "application/json" },
36
- body: JSON.stringify({ text_a: textA, text_b: textB })
37
  });
38
 
39
  const data = await res.json();
40
 
41
  if (!res.ok) {
42
- const msg = data.error || "Something went wrong.";
43
- statusEl.textContent = msg;
44
  statusEl.className = "status status-error";
45
- overallLabelEl.textContent = "";
46
- overallExplanationEl.textContent = "";
47
- scoreBoxEl.textContent = "";
48
  return;
49
  }
50
 
51
  statusEl.textContent = "Done!";
52
  statusEl.className = "status status-ok";
 
53
 
54
- const label = data.overall_label || "UNKNOWN";
55
- const confidence = data.overall_confidence || 0;
56
- const explanation = data.explanation || "";
57
-
58
- let emoji = "";
59
- if (label === "ENTAILMENT") emoji = "✅";
60
- else if (label === "CONTRADICTION") emoji = "⚠️";
61
- else if (label === "NEUTRAL") emoji = "🔍";
62
-
63
- overallLabelEl.textContent =
64
- `${emoji} ${label} (confidence: ${formatPercentage(confidence)})`;
65
- overallLabelEl.className = "overall-label";
66
- overallExplanationEl.textContent = explanation;
67
 
68
- const agg = data.aggregated_scores || {};
69
- const ent = agg.ENTAILMENT ?? 0;
70
- const con = agg.CONTRADICTION ?? 0;
71
- const neu = agg.NEUTRAL ?? 0;
72
-
73
- scoreBoxEl.className = "score-box";
74
- scoreBoxEl.innerHTML = `
75
- <div class="score-row">
76
- <span class="score-label">Entailment</span>
77
- <span class="score-value">${formatPercentage(ent)}</span>
78
- </div>
79
- <div class="score-row">
80
- <span class="score-label">Contradiction</span>
81
- <span class="score-value">${formatPercentage(con)}</span>
82
- </div>
83
- <div class="score-row">
84
- <span class="score-label">Neutral</span>
85
- <span class="score-value">${formatPercentage(neu)}</span>
86
- </div>
87
- `;
88
  } catch (err) {
89
- console.error(err);
90
- statusEl.textContent = "Error contacting the server.";
91
  statusEl.className = "status status-error";
92
  }
93
  }
94
 
95
- analyzeBtn.addEventListener("click", analyzeRelationship);
 
1
  const textAEl = document.getElementById("textA");
2
  const textBEl = document.getElementById("textB");
3
  const analyzeBtn = document.getElementById("analyzeBtn");
4
+
5
  const statusEl = document.getElementById("status");
6
+ const resultLabelEl = document.getElementById("resultLabel");
7
+ const resultExplanationEl = document.getElementById("resultExplanation");
8
  const scoreBoxEl = document.getElementById("scoreBox");
9
 
10
+ const strictTab = document.getElementById("strictTab");
11
+ const semanticTab = document.getElementById("semanticTab");
12
+
13
+ let lastResult = null;
14
+ let activeMode = "strict";
15
+
16
function switchMode(mode) {
  activeMode = mode;

  // Keep exactly one tab highlighted: the one matching the chosen mode.
  strictTab.classList.toggle("active", mode === "strict");
  semanticTab.classList.toggle("active", mode === "semantic");

  // If an analysis already ran, repaint the result in the new mode right away.
  if (lastResult) {
    updateUI();
  }
}
24
+
25
+ strictTab.onclick = () => switchMode("strict");
26
+ semanticTab.onclick = () => switchMode("semantic");
27
+
28
function formatPct(x) {
  // Render a 0..1 fraction as a percentage with one decimal place.
  const pct = x * 100;
  return `${pct.toFixed(1)}%`;
}
31
 
32
function updateUI() {
  // Nothing to paint until at least one analysis has completed.
  if (!lastResult) return;

  if (activeMode === "strict") {
    const mode = lastResult.strict_mode;
    resultLabelEl.textContent = `${mode.label} (${formatPct(mode.confidence)})`;
    resultExplanationEl.textContent = mode.explanation;

    const s = mode.scores;
    scoreBoxEl.innerHTML = `
      <div>Entailment: ${formatPct(s.ENTAILMENT)}</div>
      <div>Contradiction: ${formatPct(s.CONTRADICTION)}</div>
      <div>Neutral: ${formatPct(s.NEUTRAL)}</div>
    `;
  }

  if (activeMode === "semantic") {
    const mode = lastResult.semantic_mode;

    // Semantic mode has no numeric confidence attached to its label.
    resultLabelEl.textContent = mode.label;
    resultExplanationEl.textContent = mode.explanation;

    const s = mode.scores;
    scoreBoxEl.innerHTML = `
      <div>Agreement strength (semantic summary)</div>
      <div>Entailment: ${formatPct(s.ENTAILMENT)}</div>
      <div>Contradiction: ${formatPct(s.CONTRADICTION)}</div>
      <div>Neutral: ${formatPct(s.NEUTRAL)}</div>
    `;
  }
}
63
+
64
async function analyzeTexts() {
  // Small local helper so every status change stays in one place.
  const setStatus = (text, cls) => {
    statusEl.textContent = text;
    statusEl.className = `status ${cls}`;
  };

  const text_a = textAEl.value.trim();
  const text_b = textBEl.value.trim();

  // Both inputs are required before hitting the backend.
  if (!text_a || !text_b) {
    setStatus("Enter both texts.", "status-warn");
    return;
  }

  setStatus("Analyzing...", "status-info");

  resultLabelEl.textContent = "Computing...";
  scoreBoxEl.textContent = "";

  try {
    const res = await fetch("/analyze", {
      method: "POST",
      headers: {"Content-Type": "application/json"},
      body: JSON.stringify({ text_a, text_b })
    });

    const data = await res.json();

    if (!res.ok) {
      setStatus(data.error, "status-error");
      return;
    }

    setStatus("Done!", "status-ok");

    // Cache the payload so tab switches can re-render without refetching.
    lastResult = data;
    updateUI();
  } catch (err) {
    setStatus("Error contacting server.", "status-error");
  }
}
106
 
107
+ analyzeBtn.addEventListener("click", analyzeTexts);
templates/index.html CHANGED
@@ -7,52 +7,53 @@
7
  </head>
8
  <body>
9
  <div class="container">
 
10
  <h1>Contradiction & Entailment Demo</h1>
11
  <p class="subtitle">
12
- Compare two sentences or paragraphs to see if they are in agreement, contradict each other, or are mostly neutral.
13
  </p>
14
 
 
15
  <div class="input-section">
16
  <div class="input-card">
17
  <h2>Text A</h2>
18
- <textarea
19
- id="textA"
20
- rows="6"
21
- placeholder="Enter the first sentence or paragraph..."
22
- ></textarea>
23
  </div>
24
 
25
  <div class="input-card">
26
  <h2>Text B</h2>
27
- <textarea
28
- id="textB"
29
- rows="6"
30
- placeholder="Enter the second sentence or paragraph..."
31
- ></textarea>
32
  </div>
33
  </div>
34
 
35
  <div class="controls">
36
- <button id="analyzeBtn">Analyze Relationship</button>
37
  </div>
38
 
39
  <div id="status" class="status"></div>
40
 
 
 
 
 
 
 
 
41
  <div class="results">
 
42
  <div class="result-card">
43
- <h2>Overall Relationship</h2>
44
- <p id="overallLabel" class="overall-label placeholder">
45
- Run an analysis to see the result.
46
- </p>
47
- <p id="overallExplanation" class="overall-explanation"></p>
48
  </div>
49
 
50
  <div class="result-card">
51
  <h2>Score Breakdown</h2>
52
  <div id="scoreBox" class="score-box placeholder">
53
- Entailment, contradiction, and neutral scores will appear here.
54
  </div>
55
  </div>
 
56
  </div>
57
  </div>
58
 
 
7
  </head>
8
  <body>
9
  <div class="container">
10
+
11
  <h1>Contradiction & Entailment Demo</h1>
12
  <p class="subtitle">
13
+ Compare two texts to determine if they agree, contradict, or are neutral.
14
  </p>
15
 
16
+ <!-- Inputs -->
17
  <div class="input-section">
18
  <div class="input-card">
19
  <h2>Text A</h2>
20
+ <textarea id="textA" rows="6"></textarea>
 
 
 
 
21
  </div>
22
 
23
  <div class="input-card">
24
  <h2>Text B</h2>
25
+ <textarea id="textB" rows="6"></textarea>
 
 
 
 
26
  </div>
27
  </div>
28
 
29
  <div class="controls">
30
+ <button id="analyzeBtn">Analyze</button>
31
  </div>
32
 
33
  <div id="status" class="status"></div>
34
 
35
+ <!-- MODE SWITCH -->
36
+ <div class="mode-switch">
37
+ <button id="strictTab" class="active">Strict NLI Mode</button>
38
+ <button id="semanticTab">Semantic Agreement Mode</button>
39
+ </div>
40
+
41
+ <!-- Results -->
42
  <div class="results">
43
+
44
  <div class="result-card">
45
+ <h2>Result</h2>
46
+ <p id="resultLabel" class="overall-label placeholder">Run an analysis.</p>
47
+ <p id="resultExplanation" class="overall-explanation"></p>
 
 
48
  </div>
49
 
50
  <div class="result-card">
51
  <h2>Score Breakdown</h2>
52
  <div id="scoreBox" class="score-box placeholder">
53
+ Scores will appear here.
54
  </div>
55
  </div>
56
+
57
  </div>
58
  </div>
59