chore(repo): initialize git with .gitignore, .gitattributes, and project sources

This commit is contained in:
S
2025-10-26 08:56:41 -04:00
parent 43640e7239
commit 95979d004e
22 changed files with 4935 additions and 0 deletions

1
src/__init__.py Normal file
View File

@@ -0,0 +1 @@
# This file is intentionally left blank.

1313
src/analyze_csv.py Normal file

File diff suppressed because it is too large Load Diff

50
src/apply_labels.py Normal file
View File

@@ -0,0 +1,50 @@
import argparse
import os
import pandas as pd
def read_csv(path: str) -> pd.DataFrame:
if not os.path.exists(path):
raise SystemExit(f"CSV not found: {path}")
return pd.read_csv(path)
def main():
p = argparse.ArgumentParser(description='Apply labeled sentiments to posts/replies CSVs for analysis plots.')
p.add_argument('--labeled-csv', required=True, help='Path to labeled_sentiment.csv (must include id and label columns)')
p.add_argument('--posts-csv', required=True, help='Original posts CSV')
p.add_argument('--replies-csv', required=True, help='Original replies CSV')
p.add_argument('--posts-out', default=None, help='Output posts CSV path (default: <posts> with _with_labels suffix)')
p.add_argument('--replies-out', default=None, help='Output replies CSV path (default: <replies> with _with_labels suffix)')
args = p.parse_args()
labeled = read_csv(args.labeled_csv)
if 'id' not in labeled.columns:
raise SystemExit('labeled CSV must include an id column to merge on')
# normalize label column name to sentiment_label
lab_col = 'label' if 'label' in labeled.columns else ('sentiment_label' if 'sentiment_label' in labeled.columns else None)
if lab_col is None:
raise SystemExit("labeled CSV must include a 'label' or 'sentiment_label' column")
labeled = labeled[['id', lab_col] + (['confidence'] if 'confidence' in labeled.columns else [])].copy()
labeled = labeled.rename(columns={lab_col: 'sentiment_label'})
posts = read_csv(args.posts_csv)
replies = read_csv(args.replies_csv)
if 'id' not in posts.columns or 'id' not in replies.columns:
raise SystemExit('posts/replies CSVs must include id columns')
posts_out = args.posts_out or os.path.splitext(args.posts_csv)[0] + '_with_labels.csv'
replies_out = args.replies_out or os.path.splitext(args.replies_csv)[0] + '_with_labels.csv'
posts_merged = posts.merge(labeled, how='left', on='id', validate='m:1')
replies_merged = replies.merge(labeled, how='left', on='id', validate='m:1')
posts_merged.to_csv(posts_out, index=False)
replies_merged.to_csv(replies_out, index=False)
print(f"Wrote posts with labels -> {posts_out} (rows={len(posts_merged)})")
print(f"Wrote replies with labels -> {replies_out} (rows={len(replies_merged)})")
if __name__ == '__main__':
main()

107
src/audit_team_sentiment.py Normal file
View File

@@ -0,0 +1,107 @@
import argparse
import os
from typing import List
import pandas as pd
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
def parse_tags_column(series: pd.Series) -> pd.Series:
def _to_list(x):
if isinstance(x, list):
return x
if pd.isna(x):
return []
s = str(x)
# Expect semicolon-delimited from augmented CSV, but also accept comma
if ';' in s:
return [t.strip() for t in s.split(';') if t.strip()]
if ',' in s:
return [t.strip() for t in s.split(',') if t.strip()]
return [s] if s else []
return series.apply(_to_list)
def main():
parser = argparse.ArgumentParser(description='Audit sentiment per team tag and export samples for inspection.')
parser.add_argument('--csv', default='data/premier_league_update_tagged.csv', help='Tagged posts CSV (augmented by analyze)')
parser.add_argument('--team', default='club_manchester_united', help='Team tag to export samples for (e.g., club_manchester_united)')
parser.add_argument('--out-dir', default='data', help='Directory to write audit outputs')
parser.add_argument('--samples', type=int, default=25, help='Number of samples to export for the specified team')
parser.add_argument('--with-vader', action='store_true', help='Also compute VADER-based sentiment shares as a sanity check')
args = parser.parse_args()
if not os.path.exists(args.csv):
raise SystemExit(f"CSV not found: {args.csv}. Run analyze with --write-augmented-csv first.")
df = pd.read_csv(args.csv)
if 'message' not in df.columns:
raise SystemExit('CSV missing message column')
if 'sentiment_compound' not in df.columns:
raise SystemExit('CSV missing sentiment_compound column')
if 'tags' not in df.columns:
raise SystemExit('CSV missing tags column')
df = df.copy()
df['tags'] = parse_tags_column(df['tags'])
# Filter to team tags (prefix club_)
e = df.explode('tags')
e = e[e['tags'].notna() & (e['tags'] != '')]
e = e[e['tags'].astype(str).str.startswith('club_')]
e = e.dropna(subset=['sentiment_compound'])
if e.empty:
print('No team-tagged rows found.')
return
# Shares
e = e.copy()
e['is_pos'] = e['sentiment_compound'] > 0.05
e['is_neg'] = e['sentiment_compound'] < -0.05
grp = (
e.groupby('tags')
.agg(
n=('sentiment_compound', 'count'),
mean=('sentiment_compound', 'mean'),
median=('sentiment_compound', 'median'),
pos_share=('is_pos', 'mean'),
neg_share=('is_neg', 'mean'),
)
.reset_index()
)
grp['neu_share'] = (1 - grp['pos_share'] - grp['neg_share']).clip(lower=0)
grp = grp.sort_values(['n', 'mean'], ascending=[False, False])
if args.with_vader:
# Compute VADER shares on the underlying messages per team
analyzer = SentimentIntensityAnalyzer()
def _vader_sentiment_share(sub: pd.DataFrame):
if sub.empty:
return pd.Series({'pos_share_vader': 0.0, 'neg_share_vader': 0.0, 'neu_share_vader': 0.0})
scores = sub['message'].astype(str).apply(lambda t: analyzer.polarity_scores(t or '')['compound'])
pos = (scores > 0.05).mean()
neg = (scores < -0.05).mean()
neu = max(0.0, 1.0 - pos - neg)
return pd.Series({'pos_share_vader': pos, 'neg_share_vader': neg, 'neu_share_vader': neu})
vader_grp = e.groupby('tags').apply(_vader_sentiment_share).reset_index()
grp = grp.merge(vader_grp, on='tags', how='left')
os.makedirs(args.out_dir, exist_ok=True)
out_summary = os.path.join(args.out_dir, 'team_sentiment_audit.csv')
grp.to_csv(out_summary, index=False)
print(f"Wrote summary: {out_summary}")
# Export samples for selected team
te = e[e['tags'] == args.team].copy()
if te.empty:
print(f"No rows for team tag: {args.team}")
return
# Sort by sentiment descending to inspect highly positive claims
te = te.sort_values('sentiment_compound', ascending=False)
cols = [c for c in ['id', 'date', 'message', 'sentiment_compound', 'url'] if c in te.columns]
samples_path = os.path.join(args.out_dir, f"{args.team}_samples.csv")
te[cols].head(args.samples).to_csv(samples_path, index=False)
print(f"Wrote samples: {samples_path} ({min(args.samples, len(te))} rows)")
if __name__ == '__main__':
main()

218
src/auto_label_sentiment.py Normal file
View File

@@ -0,0 +1,218 @@
import argparse
import os
from typing import List, Optional, Tuple
import numpy as np
import pandas as pd
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
try:
# Allow both package and direct script execution
from .make_labeling_set import load_messages as _load_messages
except Exception:
from make_labeling_set import load_messages as _load_messages
def _combine_inputs(posts_csv: Optional[str], replies_csv: Optional[str], text_col: str = 'message', min_length: int = 3) -> pd.DataFrame:
frames: List[pd.DataFrame] = []
if posts_csv:
frames.append(_load_messages(posts_csv, text_col=text_col))
if replies_csv:
# include parent_id if present for replies
frames.append(_load_messages(replies_csv, text_col=text_col, extra_cols=['parent_id']))
if not frames:
raise SystemExit('No input provided. Use --input-csv or --posts-csv/--replies-csv')
df = pd.concat(frames, ignore_index=True)
df['message'] = df['message'].fillna('').astype(str)
df = df[df['message'].str.len() >= min_length]
df = df.drop_duplicates(subset=['message']).reset_index(drop=True)
return df
def _map_label_str_to_int(labels: List[str]) -> List[int]:
mapping = {'neg': 0, 'negative': 0, 'neu': 1, 'neutral': 1, 'pos': 2, 'positive': 2}
out: List[int] = []
for lab in labels:
lab_l = (lab or '').lower()
if lab_l in mapping:
out.append(mapping[lab_l])
else:
# fallback: try to parse integer
try:
out.append(int(lab))
except Exception:
out.append(1) # default to neutral
return out
def _vader_label(compound: float, pos_th: float, neg_th: float) -> str:
if compound >= pos_th:
return 'pos'
if compound <= neg_th:
return 'neg'
return 'neu'
def _auto_label_vader(texts: List[str], pos_th: float, neg_th: float, min_margin: float) -> Tuple[List[str], List[float]]:
analyzer = SentimentIntensityAnalyzer()
labels: List[str] = []
confs: List[float] = []
for t in texts:
s = analyzer.polarity_scores(t or '')
comp = float(s.get('compound', 0.0))
lab = _vader_label(comp, pos_th, neg_th)
# Confidence heuristic: distance from neutral band edges
if lab == 'pos':
conf = max(0.0, comp - pos_th)
elif lab == 'neg':
conf = max(0.0, abs(comp - neg_th))
else:
# closer to 0 is more neutral; confidence inversely related to |compound|
conf = max(0.0, (pos_th - abs(comp)))
labels.append(lab)
confs.append(conf)
# Normalize confidence roughly to [0,1] by clipping with a reasonable scale
confs = [min(1.0, c / max(1e-6, min_margin)) for c in confs]
return labels, confs
def _auto_label_transformers(texts: List[str], model_name_or_path: str, batch_size: int, min_prob: float, min_margin: float) -> Tuple[List[str], List[float]]:
try:
from .transformer_sentiment import TransformerSentiment
except Exception:
from transformer_sentiment import TransformerSentiment
clf = TransformerSentiment(model_name_or_path)
probs_all, labels_all = clf.predict_probs_and_labels(texts, batch_size=batch_size)
confs: List[float] = []
for row in probs_all:
row = np.array(row, dtype=float)
if row.size == 0:
confs.append(0.0)
continue
top2 = np.sort(row)[-2:] if row.size >= 2 else np.array([0.0, row.max()])
max_p = float(row.max())
margin = float(top2[-1] - top2[-2]) if row.size >= 2 else max_p
# Confidence must satisfy both conditions
conf = min(max(0.0, (max_p - min_prob) / max(1e-6, 1 - min_prob)), max(0.0, margin / max(1e-6, min_margin)))
confs.append(conf)
# Map arbitrary id2label names to canonical 'neg/neu/pos' when obvious; else keep as-is
canonical = []
for lab in labels_all:
ll = (lab or '').lower()
if 'neg' in ll:
canonical.append('neg')
elif 'neu' in ll or 'neutral' in ll:
canonical.append('neu')
elif 'pos' in ll or 'positive' in ll:
canonical.append('pos')
else:
canonical.append(lab)
return canonical, confs
def main():
parser = argparse.ArgumentParser(description='Automatically label sentiment without manual annotation.')
src = parser.add_mutually_exclusive_group(required=True)
src.add_argument('--input-csv', help='Single CSV containing a text column (default: message)')
src.add_argument('--posts-csv', help='Posts CSV to include')
parser.add_argument('--replies-csv', help='Replies CSV to include (combined with posts if provided)')
parser.add_argument('--text-col', default='message', help='Text column name in input CSV(s)')
parser.add_argument('-o', '--output', default='data/labeled_sentiment.csv', help='Output labeled CSV path')
parser.add_argument('--limit', type=int, default=None, help='Optional cap on number of rows')
parser.add_argument('--min-length', type=int, default=3, help='Minimum text length to consider')
parser.add_argument('--backend', choices=['vader', 'transformers', 'gpt'], default='vader', help='Labeling backend: vader, transformers, or gpt (local via Ollama)')
# VADER knobs
parser.add_argument('--vader-pos', type=float, default=0.05, help='VADER positive threshold (compound >=)')
parser.add_argument('--vader-neg', type=float, default=-0.05, help='VADER negative threshold (compound <=)')
parser.add_argument('--vader-margin', type=float, default=0.2, help='Confidence scaling for VADER distance')
# Transformers knobs
parser.add_argument('--transformers-model', default='cardiffnlp/twitter-roberta-base-sentiment-latest', help='HF model for 3-class sentiment')
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--min-prob', type=float, default=0.6, help='Min top class probability to accept')
parser.add_argument('--min-margin', type=float, default=0.2, help='Min prob gap between top-1 and top-2 to accept')
# GPT knobs
parser.add_argument('--gpt-model', default='llama3', help='Local GPT model name (Ollama)')
parser.add_argument('--gpt-base-url', default='http://localhost:11434', help='Base URL for local GPT server (Ollama)')
parser.add_argument('--gpt-batch-size', type=int, default=8)
parser.add_argument('--label-format', choices=['str', 'int'], default='str', help="Output labels as strings ('neg/neu/pos') or integers (0/1/2)")
parser.add_argument('--only-confident', action='store_true', help='Drop rows that do not meet confidence thresholds')
args = parser.parse_args()
# Load inputs
if args.input_csv:
if not os.path.exists(args.input_csv):
raise SystemExit(f"Input CSV not found: {args.input_csv}")
df = pd.read_csv(args.input_csv)
if args.text_col not in df.columns:
raise SystemExit(f"Text column '{args.text_col}' not in {args.input_csv}")
df = df.copy()
df['message'] = df[args.text_col].astype(str)
base_cols = [c for c in ['id', 'date', 'message', 'url'] if c in df.columns]
df = df[base_cols if base_cols else ['message']]
df = df[df['message'].str.len() >= args.min_length]
df = df.drop_duplicates(subset=['message']).reset_index(drop=True)
else:
df = _combine_inputs(args.posts_csv, args.replies_csv, text_col=args.text_col, min_length=args.min_length)
if args.limit and len(df) > args.limit:
df = df.head(args.limit)
texts = df['message'].astype(str).tolist()
# Predict labels + confidence
if args.backend == 'vader':
labels, conf = _auto_label_vader(texts, pos_th=args.vader_pos, neg_th=args.vader_neg, min_margin=args.vader_margin)
# For VADER, define acceptance: confident if outside neutral band by at least margin, or inside band with closeness to 0 below threshold
accept = []
analyzer = SentimentIntensityAnalyzer()
for t in texts:
comp = analyzer.polarity_scores(t or '').get('compound')
if comp is None:
accept.append(False)
continue
comp = float(comp)
if comp >= args.vader_pos + args.vader_margin or comp <= args.vader_neg - args.vader_margin:
accept.append(True)
else:
# inside or near band -> consider less confident
accept.append(False)
elif args.backend == 'transformers':
labels, conf = _auto_label_transformers(texts, args.transformers_model, args.batch_size, args.min_prob, args.min_margin)
accept = [((c >= 1.0)) or ((c >= 0.5)) for c in conf] # normalize conf ~[0,1]; accept medium-high confidence
else:
# GPT backend via Ollama: expect label+confidence
try:
from .gpt_sentiment import GPTSentiment
except Exception:
from gpt_sentiment import GPTSentiment
clf = GPTSentiment(base_url=args.gpt_base_url, model=args.gpt_model)
labels, conf = clf.predict_label_conf_batch(texts, batch_size=args.gpt_batch_size)
# Accept medium-high confidence; simple threshold like transformers path
accept = [c >= 0.5 for c in conf]
out = df.copy()
out.insert(1, 'label', labels)
out['confidence'] = conf
if args.only_confident:
out = out[np.array(accept, dtype=bool)]
out = out.reset_index(drop=True)
if args.label_format == 'int':
out['label'] = _map_label_str_to_int(out['label'].astype(str).tolist())
os.makedirs(os.path.dirname(args.output) or '.', exist_ok=True)
out.to_csv(args.output, index=False)
kept = len(out)
print(f"Wrote {kept} labeled rows to {args.output} using backend={args.backend}")
if args.only_confident:
print("Note: only confident predictions were kept. You can remove --only-confident to include all rows.")
if __name__ == '__main__':
main()

48
src/eval_sentiment.py Normal file
View File

@@ -0,0 +1,48 @@
import argparse
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report
try:
from .transformer_sentiment import TransformerSentiment
except ImportError:
# Allow running as a script via -m src.eval_sentiment
from transformer_sentiment import TransformerSentiment
def main():
parser = argparse.ArgumentParser(description='Evaluate a fine-tuned transformers sentiment model on a labeled CSV')
parser.add_argument('--csv', required=True, help='Labeled CSV path with message and label columns')
parser.add_argument('--text-col', default='message')
parser.add_argument('--label-col', default='label')
parser.add_argument('--model', required=True, help='Model name or local path')
parser.add_argument('--batch-size', type=int, default=64)
args = parser.parse_args()
df = pd.read_csv(args.csv)
df = df[[args.text_col, args.label_col]].dropna().copy()
texts = df[args.text_col].astype(str).tolist()
true_labels = df[args.label_col].astype(str).tolist()
clf = TransformerSentiment(args.model)
_, pred_labels = clf.predict_probs_and_labels(texts, batch_size=args.batch_size)
y_true = np.array(true_labels)
y_pred = np.array(pred_labels)
# If labels differ from model id2label names, normalize to strings for comparison
acc = accuracy_score(y_true, y_pred)
f1_macro = f1_score(y_true, y_pred, average='macro', zero_division=0)
prec_macro = precision_score(y_true, y_pred, average='macro', zero_division=0)
rec_macro = recall_score(y_true, y_pred, average='macro', zero_division=0)
print('Accuracy:', f"{acc:.4f}")
print('F1 (macro):', f"{f1_macro:.4f}")
print('Precision (macro):', f"{prec_macro:.4f}")
print('Recall (macro):', f"{rec_macro:.4f}")
print('\nClassification report:')
print(classification_report(y_true, y_pred, zero_division=0))
if __name__ == '__main__':
main()

131
src/fetch_schedule.py Normal file
View File

@@ -0,0 +1,131 @@
import argparse
import csv
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
import requests
from dotenv import load_dotenv
API_BASE = "https://api.football-data.org/v4"
COMPETITION_CODE = "PL" # Premier League
def iso_date(d: str) -> str:
# Accept YYYY-MM-DD and return ISO date
try:
return datetime.fromisoformat(d).date().isoformat()
except Exception as e:
raise argparse.ArgumentTypeError(f"Invalid date: {d}. Use YYYY-MM-DD") from e
def fetch_matches(start_date: str, end_date: str, token: str) -> Dict[str, Any]:
url = f"{API_BASE}/competitions/{COMPETITION_CODE}/matches"
headers = {"X-Auth-Token": token}
params = {
"dateFrom": start_date,
"dateTo": end_date,
}
r = requests.get(url, headers=headers, params=params, timeout=30)
r.raise_for_status()
return r.json()
def normalize_match(m: Dict[str, Any]) -> Dict[str, Any]:
utc_date = m.get("utcDate")
# Convert to date/time strings
kick_iso = None
if utc_date:
try:
kick_iso = datetime.fromisoformat(utc_date.replace("Z", "+00:00")).isoformat()
except Exception:
kick_iso = utc_date
score = m.get("score", {})
full_time = score.get("fullTime", {})
return {
"id": m.get("id"),
"status": m.get("status"),
"matchday": m.get("matchday"),
"utcDate": kick_iso,
"homeTeam": (m.get("homeTeam") or {}).get("name"),
"awayTeam": (m.get("awayTeam") or {}).get("name"),
"homeScore": full_time.get("home"),
"awayScore": full_time.get("away"),
"referees": ", ".join([r.get("name", "") for r in m.get("referees", []) if r.get("name")]),
"venue": m.get("area", {}).get("name"),
"competition": (m.get("competition") or {}).get("name"),
"stage": m.get("stage"),
"group": m.get("group"),
"link": m.get("id") and f"https://www.football-data.org/match/{m['id']}" or None,
}
def save_csv(matches: List[Dict[str, Any]], out_path: str) -> None:
if not matches:
# Write header only
fields = [
"id",
"status",
"matchday",
"utcDate",
"homeTeam",
"awayTeam",
"homeScore",
"awayScore",
"referees",
"venue",
"competition",
"stage",
"group",
"link",
]
with open(out_path, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=fields)
writer.writeheader()
return
fields = list(matches[0].keys())
with open(out_path, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=fields)
writer.writeheader()
writer.writerows(matches)
def save_json(matches: List[Dict[str, Any]], out_path: str) -> None:
import json
with open(out_path, "w", encoding="utf-8") as f:
json.dump(matches, f, ensure_ascii=False, indent=2)
def main():
parser = argparse.ArgumentParser(description="Fetch Premier League fixtures in a date range and save to CSV/JSON")
parser.add_argument("--start-date", required=True, type=iso_date, help="YYYY-MM-DD (inclusive)")
parser.add_argument("--end-date", required=True, type=iso_date, help="YYYY-MM-DD (inclusive)")
parser.add_argument("-o", "--output", required=True, help="Output file path (.csv or .json)")
args = parser.parse_args()
load_dotenv()
token = os.getenv("FOOTBALL_DATA_API_TOKEN")
if not token:
raise SystemExit("Missing FOOTBALL_DATA_API_TOKEN in environment (.env)")
data = fetch_matches(args.start_date, args.end_date, token)
matches_raw = data.get("matches", [])
matches = [normalize_match(m) for m in matches_raw]
os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
ext = os.path.splitext(args.output)[1].lower()
if ext == ".csv":
save_csv(matches, args.output)
elif ext == ".json":
save_json(matches, args.output)
else:
raise SystemExit("Output must end with .csv or .json")
print(f"Saved {len(matches)} matches to {args.output}")
if __name__ == "__main__":
main()

93
src/gpt_sentiment.py Normal file
View File

@@ -0,0 +1,93 @@
import json
from typing import List, Tuple
import requests
class GPTSentiment:
"""
Minimal client for a local GPT model served by Ollama.
Expects the model to respond with a strict JSON object like:
{"label": "neg|neu|pos", "confidence": 0.0..1.0}
Endpoint used: POST {base_url}/api/generate with payload:
{"model": <model>, "prompt": <prompt>, "stream": false, "format": "json"}
"""
def __init__(self, base_url: str = "http://localhost:11434", model: str = "llama3", timeout: int = 30):
self.base_url = base_url.rstrip("/")
self.model = model
self.timeout = timeout
def _build_prompt(self, text: str) -> str:
# Keep the instruction terse and deterministic; request strict JSON.
return (
"You are a strict JSON generator for sentiment analysis. "
"Classify the INPUT text as one of: neg, neu, pos. "
"Return ONLY a JSON object with keys 'label' and 'confidence' (0..1). "
"No markdown, no prose.\n\n"
f"INPUT: {text}"
)
def _call(self, prompt: str) -> dict:
url = f"{self.base_url}/api/generate"
payload = {
"model": self.model,
"prompt": prompt,
"stream": False,
"format": "json",
}
r = requests.post(url, json=payload, timeout=self.timeout)
r.raise_for_status()
data = r.json()
# Ollama returns the model's response under 'response'
raw = data.get("response", "").strip()
try:
obj = json.loads(raw)
except Exception:
# Try to recover simple cases by stripping codefences
raw2 = raw.strip().removeprefix("```").removesuffix("```")
obj = json.loads(raw2)
return obj
@staticmethod
def _canonical_label(s: str) -> str:
s = (s or "").strip().lower()
if "neg" in s:
return "neg"
if "neu" in s or "neutral" in s:
return "neu"
if "pos" in s or "positive" in s:
return "pos"
return s or "neu"
@staticmethod
def _compound_from_label_conf(label: str, confidence: float) -> float:
label = GPTSentiment._canonical_label(label)
c = max(0.0, min(1.0, float(confidence or 0.0)))
if label == "pos":
return c
if label == "neg":
return -c
return 0.0
def predict_label_conf_batch(self, texts: List[str], batch_size: int = 8) -> Tuple[List[str], List[float]]:
labels: List[str] = []
confs: List[float] = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
for t in batch:
try:
obj = self._call(self._build_prompt(t))
lab = self._canonical_label(obj.get("label", ""))
conf = float(obj.get("confidence", 0.0))
except Exception:
lab, conf = "neu", 0.0
labels.append(lab)
confs.append(conf)
return labels, confs
def predict_compound_batch(self, texts: List[str], batch_size: int = 8) -> List[float]:
labels, confs = self.predict_label_conf_batch(texts, batch_size=batch_size)
return [self._compound_from_label_conf(lab, conf) for lab, conf in zip(labels, confs)]

65
src/make_labeling_set.py Normal file
View File

@@ -0,0 +1,65 @@
import argparse
import os
import pandas as pd
def load_messages(csv_path: str, text_col: str = 'message', extra_cols=None) -> pd.DataFrame:
if not os.path.exists(csv_path):
return pd.DataFrame()
df = pd.read_csv(csv_path)
if text_col not in df.columns:
return pd.DataFrame()
cols = ['id', text_col, 'date']
if extra_cols:
for c in extra_cols:
if c in df.columns:
cols.append(c)
cols = [c for c in cols if c in df.columns]
out = df[cols].copy()
out.rename(columns={text_col: 'message'}, inplace=True)
return out
def main():
parser = argparse.ArgumentParser(description='Create a labeling CSV from posts and/or replies.')
parser.add_argument('--posts-csv', required=False, help='Posts CSV path (e.g., data/..._update.csv)')
parser.add_argument('--replies-csv', required=False, help='Replies CSV path')
parser.add_argument('-o', '--output', default='data/labeled_sentiment.csv', help='Output CSV for labeling')
parser.add_argument('--sample-size', type=int, default=1000, help='Total rows to include (after combining)')
parser.add_argument('--min-length', type=int, default=3, help='Minimum message length to include')
parser.add_argument('--shuffle', action='store_true', help='Shuffle before sampling (default true)')
parser.add_argument('--no-shuffle', dest='shuffle', action='store_false')
parser.set_defaults(shuffle=True)
args = parser.parse_args()
frames = []
if args.posts_csv:
frames.append(load_messages(args.posts_csv))
if args.replies_csv:
# For replies, include parent_id if present
r = load_messages(args.replies_csv, extra_cols=['parent_id'])
frames.append(r)
if not frames:
raise SystemExit('No input CSVs provided. Use --posts-csv and/or --replies-csv.')
df = pd.concat(frames, ignore_index=True)
# Basic filtering: non-empty text, min length, drop duplicates by message text
df['message'] = df['message'].fillna('').astype(str)
df = df[df['message'].str.len() >= args.min_length]
df = df.drop_duplicates(subset=['message']).reset_index(drop=True)
if args.shuffle:
df = df.sample(frac=1.0, random_state=42).reset_index(drop=True)
if args.sample_size and len(df) > args.sample_size:
df = df.head(args.sample_size)
# Add blank label column for human annotation
df.insert(1, 'label', '')
os.makedirs(os.path.dirname(args.output) or '.', exist_ok=True)
df.to_csv(args.output, index=False)
print(f"Wrote labeling CSV with {len(df)} rows to {args.output}")
if __name__ == '__main__':
main()

137
src/plot_labeled.py Normal file
View File

@@ -0,0 +1,137 @@
import argparse
import os
from typing import Optional
import pandas as pd
def safe_read(path: str) -> pd.DataFrame:
if not os.path.exists(path):
raise SystemExit(f"Input labeled CSV not found: {path}")
df = pd.read_csv(path)
if 'label' not in df.columns:
raise SystemExit("Expected a 'label' column in the labeled CSV")
if 'message' in df.columns:
df['message'] = df['message'].fillna('').astype(str)
if 'confidence' in df.columns:
df['confidence'] = pd.to_numeric(df['confidence'], errors='coerce')
if 'date' in df.columns:
df['date'] = pd.to_datetime(df['date'], errors='coerce')
return df
def ensure_out_dir(out_dir: str) -> str:
os.makedirs(out_dir, exist_ok=True)
return out_dir
def plot_all(df: pd.DataFrame, out_dir: str) -> None:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
out_dir = ensure_out_dir(out_dir)
# 1) Class distribution
try:
plt.figure(figsize=(6,4))
ax = (df['label'].astype(str).str.lower().value_counts()
.reindex(['neg','neu','pos'])
.fillna(0)
.rename_axis('label').reset_index(name='count')
.set_index('label')
.plot(kind='bar', legend=False, color=['#d62728','#aaaaaa','#2ca02c']))
plt.title('Labeled class distribution')
plt.ylabel('Count')
plt.tight_layout()
path = os.path.join(out_dir, 'labeled_class_distribution.png')
plt.savefig(path, dpi=150)
plt.close()
print(f"[plots] Saved {path}")
except Exception as e:
print(f"[plots] Skipped class distribution: {e}")
# 2) Confidence histogram (overall)
if 'confidence' in df.columns and df['confidence'].notna().any():
try:
plt.figure(figsize=(6,4))
sns.histplot(df['confidence'].dropna(), bins=30, color='#1f77b4')
plt.title('Confidence distribution (overall)')
plt.xlabel('Confidence'); plt.ylabel('Frequency')
plt.tight_layout()
path = os.path.join(out_dir, 'labeled_confidence_hist.png')
plt.savefig(path, dpi=150); plt.close()
print(f"[plots] Saved {path}")
except Exception as e:
print(f"[plots] Skipped confidence histogram: {e}")
# 3) Confidence by label (boxplot)
try:
plt.figure(figsize=(6,4))
t = df[['label','confidence']].dropna()
t['label'] = t['label'].astype(str).str.lower()
order = ['neg','neu','pos']
sns.boxplot(data=t, x='label', y='confidence', order=order, palette=['#d62728','#aaaaaa','#2ca02c'])
plt.title('Confidence by label')
plt.xlabel('Label'); plt.ylabel('Confidence')
plt.tight_layout()
path = os.path.join(out_dir, 'labeled_confidence_by_label.png')
plt.savefig(path, dpi=150); plt.close()
print(f"[plots] Saved {path}")
except Exception as e:
print(f"[plots] Skipped confidence by label: {e}")
# 4) Message length by label
if 'message' in df.columns:
try:
t = df[['label','message']].copy()
t['label'] = t['label'].astype(str).str.lower()
t['len'] = t['message'].astype(str).str.len()
plt.figure(figsize=(6,4))
sns.boxplot(data=t, x='label', y='len', order=['neg','neu','pos'], palette=['#d62728','#aaaaaa','#2ca02c'])
plt.title('Message length by label')
plt.xlabel('Label'); plt.ylabel('Length (chars)')
plt.tight_layout()
path = os.path.join(out_dir, 'labeled_length_by_label.png')
plt.savefig(path, dpi=150); plt.close()
print(f"[plots] Saved {path}")
except Exception as e:
print(f"[plots] Skipped length by label: {e}")
# 5) Daily counts per label (if date present)
if 'date' in df.columns and df['date'].notna().any():
try:
t = df[['date','label']].dropna().copy()
t['day'] = pd.to_datetime(t['date'], errors='coerce').dt.date
t['label'] = t['label'].astype(str).str.lower()
pv = t.pivot_table(index='day', columns='label', values='date', aggfunc='count').fillna(0)
# ensure consistent column order
for c in ['neg','neu','pos']:
if c not in pv.columns:
pv[c] = 0
pv = pv[['neg','neu','pos']]
import matplotlib.pyplot as plt
plt.figure(figsize=(10,4))
pv.plot(kind='bar', stacked=True, color=['#d62728','#aaaaaa','#2ca02c'])
plt.title('Daily labeled counts (stacked)')
plt.xlabel('Day'); plt.ylabel('Count')
plt.tight_layout()
path = os.path.join(out_dir, 'labeled_daily_counts.png')
plt.savefig(path, dpi=150); plt.close()
print(f"[plots] Saved {path}")
except Exception as e:
print(f"[plots] Skipped daily counts: {e}")
def main():
parser = argparse.ArgumentParser(description='Plot graphs from labeled sentiment data.')
parser.add_argument('-i', '--input', default='data/labeled_sentiment.csv', help='Path to labeled CSV')
parser.add_argument('-o', '--out-dir', default='data', help='Output directory for plots')
args = parser.parse_args()
df = safe_read(args.input)
plot_all(df, args.out_dir)
if __name__ == '__main__':
main()

749
src/telegram_scraper.py Normal file
View File

@@ -0,0 +1,749 @@
import asyncio
import json
import os
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import AsyncIterator, Iterable, Optional, Sequence, Set, List, Tuple
from dotenv import load_dotenv
from telethon import TelegramClient
from telethon.errors import SessionPasswordNeededError
from telethon.errors.rpcerrorlist import MsgIdInvalidError, FloodWaitError
from telethon.tl.functions.messages import GetDiscussionMessageRequest
from telethon.tl.custom.message import Message
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
@dataclass
class ScrapedMessage:
id: int
date: Optional[str] # ISO format
message: Optional[str]
sender_id: Optional[int]
views: Optional[int]
forwards: Optional[int]
replies: Optional[int]
url: Optional[str]
def to_iso(dt: datetime) -> str:
return dt.replace(tzinfo=None).isoformat()
async def iter_messages(
    client: TelegramClient,
    entity: str,
    limit: Optional[int] = None,
    offset_date: Optional[datetime] = None,
) -> AsyncIterator[Message]:
    """Thin async pass-through over Telethon's ``iter_messages`` for *entity*.

    Parameters are forwarded unchanged; messages are yielded as Telethon
    produces them (newest first by Telethon's default ordering).
    """
    stream = client.iter_messages(entity, limit=limit, offset_date=offset_date)
    async for item in stream:
        yield item
def message_to_record(msg: Message, channel_username: str) -> ScrapedMessage:
    """Flatten a Telethon ``Message`` into a ``ScrapedMessage`` export row."""
    reply_info = getattr(msg, 'replies', None)
    if hasattr(msg, 'sender_id'):
        sender = getattr(msg.sender_id, 'value', msg.sender_id)
    else:
        sender = None
    # A permalink can only be built when the channel has a public username.
    link = f"https://t.me/{channel_username}/{msg.id}" if channel_username else None
    return ScrapedMessage(
        id=msg.id,
        date=to_iso(msg.date) if msg.date else None,
        message=msg.message,
        sender_id=sender,
        views=getattr(msg, 'views', None),
        forwards=getattr(msg, 'forwards', None),
        replies=reply_info.replies if reply_info else None,
        url=link,
    )
async def ensure_login(client: TelegramClient, phone: Optional[str] = None, twofa_password: Optional[str] = None):
    """Connect the client and complete an interactive login if not yet authorized.

    Prompts on stdin for the phone number, the login code, and (when two-step
    verification is enabled) the account password, unless the caller supplied
    them.
    """
    # Connect and log in, prompting interactively if needed
    await client.connect()
    if not await client.is_user_authorized():
        if not phone:
            phone = input("Enter your phone number (with country code): ")
        await client.send_code_request(phone)
        code = input("Enter the login code you received: ")
        try:
            await client.sign_in(phone=phone, code=code)
        except SessionPasswordNeededError:
            # Account has 2FA enabled: a password is required in addition to the code.
            if twofa_password is None:
                twofa_password = input("Two-step verification enabled. Enter your password: ")
            await client.sign_in(password=twofa_password)
async def scrape_channel(
    channel: str,
    output: str,
    limit: Optional[int] = None,
    offset_date: Optional[str] = None,  # deprecated in favor of start_date
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    append: bool = False,
    session_name: str = "telegram",
    phone: Optional[str] = None,
    twofa_password: Optional[str] = None,
):
    """Scrape a channel's history into a JSONL/NDJSON or CSV file.

    Iterates the channel newest-first, keeps messages whose date falls within
    [start_date, end_date] (inclusive), and stops early once messages older
    than start_date are reached. Credentials come from TELEGRAM_API_ID /
    TELEGRAM_API_HASH in the environment (.env supported via dotenv).

    Returns:
        The number of records written after filtering.

    Raises:
        RuntimeError: If API credentials are missing or api_id is not an int.
        ValueError: If *output* does not end with .jsonl/.ndjson/.csv.
    """
    load_dotenv()
    api_id = os.getenv("TELEGRAM_API_ID")
    api_hash = os.getenv("TELEGRAM_API_HASH")
    session_name = os.getenv("TELEGRAM_SESSION_NAME", session_name)
    if not api_id or not api_hash:
        raise RuntimeError("Missing TELEGRAM_API_ID/TELEGRAM_API_HASH in environment. See .env.example")
    # Some providers store api_id as string; Telethon expects int
    try:
        api_id_int = int(api_id)
    except Exception as e:
        raise RuntimeError("TELEGRAM_API_ID must be an integer") from e
    client = TelegramClient(session_name, api_id_int, api_hash)
    # Parse date filters
    parsed_start = None
    parsed_end = None
    if start_date:
        parsed_start = datetime.fromisoformat(start_date)
    elif offset_date:  # backward compatibility
        parsed_start = datetime.fromisoformat(offset_date)
    if end_date:
        parsed_end = datetime.fromisoformat(end_date)
    await ensure_login(client, phone=phone, twofa_password=twofa_password)
    # Determine output format based on extension
    ext = os.path.splitext(output)[1].lower()
    is_jsonl = ext in (".jsonl", ".ndjson")
    is_csv = ext == ".csv"
    if not (is_jsonl or is_csv):
        raise ValueError("Output file must end with .jsonl or .csv")
    # Prepare output writers
    csv_file = None
    csv_writer = None
    jsonl_file = None
    if is_csv:
        import csv
        mode = "a" if append else "w"
        csv_file = open(output, mode, newline="", encoding="utf-8")
        csv_writer = csv.DictWriter(
            csv_file,
            fieldnames=[
                "id",
                "date",
                "message",
                "sender_id",
                "views",
                "forwards",
                "replies",
                "url",
            ],
        )
        # Write header if not appending, or file is empty
        need_header = True
        try:
            if append and os.path.exists(output) and os.path.getsize(output) > 0:
                need_header = False
        except Exception:
            pass
        if need_header:
            csv_writer.writeheader()
    elif is_jsonl:
        # Open once; append or overwrite
        mode = "a" if append else "w"
        jsonl_file = open(output, mode, encoding="utf-8")
    written = 0
    try:
        # Fetch unfiltered here; the limit/date filters are applied manually below
        # so that `limit` counts messages AFTER filtering.
        async for msg in iter_messages(client, channel, limit=None, offset_date=None):
            # Telethon returns tz-aware datetimes; normalize for comparison
            msg_dt = msg.date
            if msg_dt is not None:
                msg_dt = msg_dt.replace(tzinfo=None)
            # Date range filter: include if within [parsed_start, parsed_end] (inclusive)
            if parsed_start and msg_dt and msg_dt < parsed_start:
                # Since we're iterating newest-first, once older than start we can stop
                break
            if parsed_end and msg_dt and msg_dt > parsed_end:
                continue
            rec = message_to_record(msg, channel_username=channel.lstrip("@"))
            if is_jsonl and jsonl_file is not None:
                jsonl_file.write(json.dumps(asdict(rec), ensure_ascii=False) + "\n")
            else:
                csv_writer.writerow(asdict(rec))  # type: ignore
            written += 1
            if limit is not None and written >= limit:
                break
    finally:
        # Always release file handles and the Telethon connection.
        if csv_file:
            csv_file.close()
        if jsonl_file:
            jsonl_file.close()
        await client.disconnect()
    return written
async def fetch_replies(
    channel: str,
    parent_ids: Sequence[int],
    output_csv: str,
    append: bool = False,
    session_name: str = "telegram",
    phone: Optional[str] = None,
    twofa_password: Optional[str] = None,
    concurrency: int = 5,
    existing_pairs: Optional[Set[Tuple[int, int]]] = None,
):
    """Fetch replies for each parent message id and stream them to a CSV.

    Tries in-channel replies first (works for groups/supergroups); on
    MsgIdInvalidError falls back to the channel's linked discussion group.
    Each row carries a VADER compound sentiment score for the reply text.

    Args:
        channel: Channel username or t.me link.
        parent_ids: Parent message ids whose replies should be collected.
        output_csv: Destination CSV path (parent directories are created).
        append: Append to an existing CSV instead of overwriting.
        session_name: Telethon session name (TELEGRAM_SESSION_NAME overrides).
        phone: Phone number for interactive login, prompted if omitted.
        twofa_password: Two-step verification password, prompted if needed.
        concurrency: Number of parent ids processed in parallel.
        existing_pairs: Already-seen (parent_id, id) pairs for resume mode;
            mutated in place as new rows are written.

    Raises:
        RuntimeError: If API credentials are missing from the environment.
    """
    load_dotenv()
    api_id = os.getenv("TELEGRAM_API_ID")
    api_hash = os.getenv("TELEGRAM_API_HASH")
    session_name = os.getenv("TELEGRAM_SESSION_NAME", session_name)
    if not api_id or not api_hash:
        raise RuntimeError("Missing TELEGRAM_API_ID/TELEGRAM_API_HASH in environment. See .env.example")
    client = TelegramClient(session_name, int(api_id), api_hash)
    await ensure_login(client, phone=phone, twofa_password=twofa_password)
    import csv
    # Rate limiting counters
    flood_hits = 0
    flood_wait_seconds = 0
    analyzer = SentimentIntensityAnalyzer()
    os.makedirs(os.path.dirname(output_csv) or ".", exist_ok=True)
    mode = "a" if append else "w"
    with open(output_csv, mode, newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=["parent_id", "id", "date", "message", "sender_id", "sentiment_compound", "url"],
        )
        # Write header only if not appending or file empty
        need_header = True
        try:
            if append and os.path.exists(output_csv) and os.path.getsize(output_csv) > 0:
                need_header = False
        except Exception:
            pass
        if need_header:
            writer.writeheader()
        write_lock = asyncio.Lock()
        sem = asyncio.Semaphore(max(1, int(concurrency)))

        async def handle_parent(pid: int) -> List[dict]:
            # BUG FIX: these counters live in fetch_replies' scope; without the
            # nonlocal declaration the `+=` below raised UnboundLocalError the
            # moment a FloodWaitError was actually hit.
            nonlocal flood_hits, flood_wait_seconds
            rows: List[dict] = []
            # First try replies within the same channel (works for groups/supergroups)
            attempts = 0
            while attempts < 3:
                try:
                    async for reply in client.iter_messages(channel, reply_to=pid):
                        dt = reply.date.replace(tzinfo=None) if reply.date else None
                        url = f"https://t.me/{channel.lstrip('@')}/{reply.id}" if reply.id else None
                        text = reply.message or ""
                        sent = analyzer.polarity_scores(text).get("compound")
                        rows.append(
                            {
                                "parent_id": pid,
                                "id": reply.id,
                                "date": to_iso(dt) if dt else None,
                                "message": text,
                                "sender_id": getattr(reply, "sender_id", None),
                                "sentiment_compound": sent,
                                "url": url,
                            }
                        )
                    break
                except FloodWaitError as e:
                    secs = int(getattr(e, 'seconds', 5))
                    flood_hits += 1
                    flood_wait_seconds += secs
                    print(f"[rate-limit] FloodWait while scanning replies in-channel for parent {pid}; waiting {secs}s", flush=True)
                    await asyncio.sleep(secs + 1)
                    attempts += 1
                    continue
                except MsgIdInvalidError:
                    # Likely a channel with a linked discussion group; fall back below
                    rows.clear()
                    break
                except Exception:
                    break
            if rows:
                return rows
            # Fallback: for channels with comments in a linked discussion group
            try:
                res = await client(GetDiscussionMessageRequest(peer=channel, msg_id=pid))
            except Exception:
                # No discussion thread found or not accessible
                return rows
            # Identify the discussion chat and the root message id in that chat
            disc_chat = None
            if getattr(res, "chats", None):
                # Prefer the first chat returned as the discussion chat
                disc_chat = res.chats[0]
            disc_root_id = None
            for m in getattr(res, "messages", []) or []:
                try:
                    peer_id = getattr(m, "peer_id", None)
                    if not peer_id or not disc_chat:
                        continue
                    ch_id = getattr(peer_id, "channel_id", None) or getattr(peer_id, "chat_id", None)
                    if ch_id == getattr(disc_chat, "id", None):
                        disc_root_id = m.id
                        break
                except Exception:
                    continue
            if not disc_chat or not disc_root_id:
                return rows
            group_username = getattr(disc_chat, "username", None)
            attempts = 0
            while attempts < 3:
                try:
                    async for reply in client.iter_messages(disc_chat, reply_to=disc_root_id):
                        dt = reply.date.replace(tzinfo=None) if reply.date else None
                        text = reply.message or ""
                        sent = analyzer.polarity_scores(text).get("compound")
                        # Construct URL only if the discussion group has a public username
                        url = None
                        if group_username and reply.id:
                            url = f"https://t.me/{group_username}/{reply.id}"
                        rows.append(
                            {
                                "parent_id": pid,
                                "id": reply.id,
                                "date": to_iso(dt) if dt else None,
                                "message": text,
                                "sender_id": getattr(reply, "sender_id", None),
                                "sentiment_compound": sent,
                                "url": url,
                            }
                        )
                    break
                except FloodWaitError as e:
                    secs = int(getattr(e, 'seconds', 5))
                    flood_hits += 1
                    flood_wait_seconds += secs
                    print(f"[rate-limit] FloodWait while scanning discussion group for parent {pid}; waiting {secs}s", flush=True)
                    await asyncio.sleep(secs + 1)
                    attempts += 1
                    continue
                except Exception:
                    break
            return rows

        total_written = 0
        processed = 0
        total = len(list(parent_ids)) if hasattr(parent_ids, '__len__') else None

        async def worker(pid: int):
            nonlocal total_written, processed
            async with sem:
                rows = await handle_parent(int(pid))
                # Serialize CSV writes and progress counters across workers.
                async with write_lock:
                    if rows:
                        # Dedupe against existing pairs if provided (resume mode)
                        if existing_pairs is not None:
                            filtered: List[dict] = []
                            for r in rows:
                                try:
                                    key = (int(r.get("parent_id")), int(r.get("id")))
                                except Exception:
                                    continue
                                if key in existing_pairs:
                                    continue
                                existing_pairs.add(key)
                                filtered.append(r)
                            rows = filtered
                        if rows:
                            writer.writerows(rows)
                            total_written += len(rows)
                    processed += 1
                    if processed % 10 == 0 or (rows and len(rows) > 0):
                        if total is not None:
                            print(f"[replies] processed {processed}/{total} parents; last parent {pid} wrote {len(rows)} replies; total replies {total_written}", flush=True)
                        else:
                            print(f"[replies] processed {processed} parents; last parent {pid} wrote {len(rows)} replies; total replies {total_written}", flush=True)

        tasks = [asyncio.create_task(worker(pid)) for pid in parent_ids]
        await asyncio.gather(*tasks)
    await client.disconnect()
    if flood_hits:
        print(f"[rate-limit] Summary: {flood_hits} FloodWait events; total waited ~{flood_wait_seconds}s", flush=True)
async def fetch_forwards(
    channel: str,
    parent_ids: Set[int],
    output_csv: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    scan_limit: Optional[int] = None,
    append: bool = False,
    session_name: str = "telegram",
    phone: Optional[str] = None,
    twofa_password: Optional[str] = None,
    concurrency: int = 5,
    chunk_size: int = 1000,
):
    """Best-effort: find forwarded messages within the SAME channel that reference the given parent_ids.

    Telegram API does not provide a global reverse-lookup of forwards across all channels; we therefore scan
    this channel's history and collect messages with fwd_from.channel_post matching a parent id.

    Args:
        channel: Channel username or t.me link.
        parent_ids: Parent message ids to match against fwd_from.channel_post.
        output_csv: Destination CSV path (parent directories are created).
        start_date: ISO start date (inclusive) for scanned messages.
        end_date: ISO end date (inclusive) for scanned messages.
        scan_limit: When set, enables concurrent id-range chunked scanning of
            roughly this many messages; when None, scans history sequentially.
        append: Append to an existing CSV instead of overwriting.
        session_name: Telethon session name (TELEGRAM_SESSION_NAME overrides).
        phone: Phone number for interactive login, prompted if omitted.
        twofa_password: Two-step verification password, prompted if needed.
        concurrency: Number of id-chunks scanned in parallel (chunked mode).
        chunk_size: Approximate number of message ids per chunk (chunked mode).

    Raises:
        RuntimeError: If API credentials are missing from the environment.
    """
    load_dotenv()
    api_id = os.getenv("TELEGRAM_API_ID")
    api_hash = os.getenv("TELEGRAM_API_HASH")
    session_name = os.getenv("TELEGRAM_SESSION_NAME", session_name)
    if not api_id or not api_hash:
        raise RuntimeError("Missing TELEGRAM_API_ID/TELEGRAM_API_HASH in environment. See .env.example")
    client = TelegramClient(session_name, int(api_id), api_hash)
    await ensure_login(client, phone=phone, twofa_password=twofa_password)
    import csv  # single import (the original imported csv twice)
    # Rate limiting counters — initialized once so the summary below is valid
    # regardless of which scan path ran.
    flood_hits = 0
    flood_wait_seconds = 0
    analyzer = SentimentIntensityAnalyzer()
    os.makedirs(os.path.dirname(output_csv) or ".", exist_ok=True)
    mode = "a" if append else "w"
    write_lock = asyncio.Lock()
    with open(output_csv, mode, newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=["parent_id", "id", "date", "message", "sender_id", "sentiment_compound", "url"],
        )
        need_header = True
        try:
            if append and os.path.exists(output_csv) and os.path.getsize(output_csv) > 0:
                need_header = False
        except Exception:
            pass
        if need_header:
            writer.writeheader()
        parsed_start = datetime.fromisoformat(start_date) if start_date else None
        parsed_end = datetime.fromisoformat(end_date) if end_date else None
        # If no scan_limit provided, fall back to sequential scan to avoid unbounded concurrency
        if scan_limit is None:
            scanned = 0
            matched = 0
            async for msg in client.iter_messages(channel, limit=None):
                dt = msg.date.replace(tzinfo=None) if msg.date else None
                if parsed_start and dt and dt < parsed_start:
                    break
                if parsed_end and dt and dt > parsed_end:
                    continue
                fwd = getattr(msg, "fwd_from", None)
                if not fwd:
                    continue
                ch_post = getattr(fwd, "channel_post", None)
                if ch_post and int(ch_post) in parent_ids:
                    text = msg.message or ""
                    sent = analyzer.polarity_scores(text).get("compound")
                    url = f"https://t.me/{channel.lstrip('@')}/{msg.id}" if msg.id else None
                    writer.writerow(
                        {
                            "parent_id": int(ch_post),
                            "id": msg.id,
                            "date": to_iso(dt) if dt else None,
                            "message": text,
                            "sender_id": getattr(msg, "sender_id", None),
                            "sentiment_compound": sent,
                            "url": url,
                        }
                    )
                    matched += 1
                scanned += 1
                if scanned % 1000 == 0:
                    print(f"[forwards] scanned ~{scanned} messages; total forwards {matched}", flush=True)
        else:
            # Concurrent chunked scanning by id ranges
            sem = asyncio.Semaphore(max(1, int(concurrency)))
            progress_lock = asyncio.Lock()
            matched_total = 0
            completed_chunks = 0
            # Determine latest message id
            latest_msg = await client.get_messages(channel, limit=1)
            latest_id = None
            try:
                latest_id = getattr(latest_msg, 'id', None) or (latest_msg[0].id if latest_msg else None)
            except Exception:
                latest_id = None
            if not latest_id:
                await client.disconnect()
                return
            total_chunks = max(1, (int(scan_limit) + int(chunk_size) - 1) // int(chunk_size))

            async def process_chunk(idx: int):
                nonlocal flood_hits, flood_wait_seconds
                nonlocal matched_total, completed_chunks
                max_id = latest_id - idx * int(chunk_size)
                min_id = max(0, max_id - int(chunk_size))
                attempts = 0
                local_matches = 0
                while attempts < 3:
                    try:
                        async with sem:
                            async for msg in client.iter_messages(channel, min_id=min_id, max_id=max_id):
                                dt = msg.date.replace(tzinfo=None) if msg.date else None
                                if parsed_start and dt and dt < parsed_start:
                                    # This range reached before start; skip remaining in this chunk
                                    break
                                if parsed_end and dt and dt > parsed_end:
                                    continue
                                fwd = getattr(msg, "fwd_from", None)
                                if not fwd:
                                    continue
                                ch_post = getattr(fwd, "channel_post", None)
                                if ch_post and int(ch_post) in parent_ids:
                                    text = msg.message or ""
                                    sent = analyzer.polarity_scores(text).get("compound")
                                    url = f"https://t.me/{channel.lstrip('@')}/{msg.id}" if msg.id else None
                                    async with write_lock:
                                        writer.writerow(
                                            {
                                                "parent_id": int(ch_post),
                                                "id": msg.id,
                                                "date": to_iso(dt) if dt else None,
                                                "message": text,
                                                "sender_id": getattr(msg, "sender_id", None),
                                                "sentiment_compound": sent,
                                                "url": url,
                                            }
                                        )
                                    local_matches += 1
                        break
                    except FloodWaitError as e:
                        secs = int(getattr(e, 'seconds', 5))
                        flood_hits += 1
                        flood_wait_seconds += secs
                        print(f"[rate-limit] FloodWait while scanning ids {min_id}-{max_id}; waiting {secs}s", flush=True)
                        await asyncio.sleep(secs + 1)
                        attempts += 1
                        continue
                    except Exception:
                        # best-effort; skip this chunk
                        break
                async with progress_lock:
                    matched_total += local_matches
                    completed_chunks += 1
                    print(
                        f"[forwards] chunks {completed_chunks}/{total_chunks}; last {min_id}-{max_id} wrote {local_matches} forwards; total forwards {matched_total}",
                        flush=True,
                    )

            tasks = [asyncio.create_task(process_chunk(i)) for i in range(total_chunks)]
            await asyncio.gather(*tasks)
    await client.disconnect()
    # Counters only increment in the chunked path, so this prints nothing for
    # sequential scans — same observable behavior as before, minus the
    # fragile `'flood_hits' in locals()` guard.
    if flood_hits:
        print(f"[rate-limit] Summary: {flood_hits} FloodWait events; total waited ~{flood_wait_seconds}s", flush=True)
def main():
    """CLI entry point: dispatch the scrape / replies / forwards subcommands."""
    import argparse
    parser = argparse.ArgumentParser(description="Telegram scraper utilities")
    sub = parser.add_subparsers(dest="command", required=True)
    # Subcommand: scrape channel history
    p_scrape = sub.add_parser("scrape", help="Scrape messages from a channel")
    p_scrape.add_argument("channel", help="Channel username or t.me link, e.g. @python, https://t.me/python")
    p_scrape.add_argument("--output", "-o", required=True, help="Output file (.jsonl or .csv)")
    p_scrape.add_argument("--limit", type=int, default=None, help="Max number of messages to save after filtering")
    p_scrape.add_argument("--offset-date", dest="offset_date", default=None, help="Deprecated: use --start-date instead. ISO date (inclusive)")
    p_scrape.add_argument("--start-date", dest="start_date", default=None, help="ISO start date (inclusive)")
    p_scrape.add_argument("--end-date", dest="end_date", default=None, help="ISO end date (inclusive)")
    p_scrape.add_argument("--append", action="store_true", help="Append to the output file instead of overwriting")
    p_scrape.add_argument("--session-name", default=os.getenv("TELEGRAM_SESSION_NAME", "telegram"))
    p_scrape.add_argument("--phone", default=None)
    p_scrape.add_argument("--twofa-password", default=os.getenv("TELEGRAM_2FA_PASSWORD"))
    # Subcommand: fetch replies for specific message ids
    p_rep = sub.add_parser("replies", help="Fetch replies for given message IDs and save to CSV")
    p_rep.add_argument("channel", help="Channel username or t.me link")
    src = p_rep.add_mutually_exclusive_group(required=True)
    src.add_argument("--ids", help="Comma-separated parent message IDs")
    src.add_argument("--from-csv", dest="from_csv", help="Path to CSV with an 'id' column to use as parent IDs")
    p_rep.add_argument("--output", "-o", required=True, help="Output CSV path (e.g., data/replies_channel.csv)")
    p_rep.add_argument("--append", action="store_true", help="Append to the output file instead of overwriting")
    p_rep.add_argument("--session-name", default=os.getenv("TELEGRAM_SESSION_NAME", "telegram"))
    p_rep.add_argument("--phone", default=None)
    p_rep.add_argument("--twofa-password", default=os.getenv("TELEGRAM_2FA_PASSWORD"))
    p_rep.add_argument("--concurrency", type=int, default=5, help="Number of parent IDs to process in parallel (default 5)")
    p_rep.add_argument("--min-replies", type=int, default=None, help="When using --from-csv, only process parents with replies >= this value")
    p_rep.add_argument("--resume", action="store_true", help="Resume mode: skip parent_id,id pairs already present in the output CSV")
    # Subcommand: fetch forwards (same-channel forwards referencing parent ids)
    p_fwd = sub.add_parser("forwards", help="Best-effort: find forwards within the same channel for given parent IDs")
    p_fwd.add_argument("channel", help="Channel username or t.me link")
    src2 = p_fwd.add_mutually_exclusive_group(required=True)
    src2.add_argument("--ids", help="Comma-separated parent message IDs")
    src2.add_argument("--from-csv", dest="from_csv", help="Path to CSV with an 'id' column to use as parent IDs")
    p_fwd.add_argument("--output", "-o", required=True, help="Output CSV path (e.g., data/forwards_channel.csv)")
    p_fwd.add_argument("--start-date", dest="start_date", default=None)
    p_fwd.add_argument("--end-date", dest="end_date", default=None)
    p_fwd.add_argument("--scan-limit", dest="scan_limit", type=int, default=None, help="Max messages to scan in channel history")
    p_fwd.add_argument("--concurrency", type=int, default=5, help="Number of id-chunks to scan in parallel (requires --scan-limit)")
    p_fwd.add_argument("--chunk-size", dest="chunk_size", type=int, default=1000, help="Approx. messages per chunk (ids)")
    p_fwd.add_argument("--append", action="store_true", help="Append to the output file instead of overwriting")
    p_fwd.add_argument("--session-name", default=os.getenv("TELEGRAM_SESSION_NAME", "telegram"))
    p_fwd.add_argument("--phone", default=None)
    p_fwd.add_argument("--twofa-password", default=os.getenv("TELEGRAM_2FA_PASSWORD"))
    args = parser.parse_args()
    # Normalize channel
    channel = getattr(args, "channel", None)
    if channel and channel.startswith("https://t.me/"):
        channel = channel.replace("https://t.me/", "@")

    def _normalize_handle(ch: Optional[str]) -> Optional[str]:
        # Lowercase handle without the leading '@'; used for mismatch warnings.
        if not ch:
            return ch
        # Expect inputs like '@name' or 'name'; return lowercase without leading '@'
        return ch.lstrip('@').lower()

    def _extract_handle_from_url(url: str) -> Optional[str]:
        # Pull the public username out of a t.me permalink, or None.
        try:
            if not url:
                return None
            # Accept forms like https://t.me/Name/123 or http(s)://t.me/c/<id>/<msg>
            # Only public usernames (not /c/ links) can be compared reliably
            if "/t.me/" in url:
                # crude parse without urlparse to avoid dependency
                after = url.split("t.me/")[-1]
                parts = after.split('/')
                if parts and parts[0] and parts[0] != 'c':
                    return parts[0]
        except Exception:
            return None
        return None

    if args.command == "scrape":
        written = asyncio.run(
            scrape_channel(
                channel=channel,
                output=args.output,
                limit=args.limit,
                offset_date=args.offset_date,
                start_date=args.start_date,
                end_date=args.end_date,
                append=getattr(args, "append", False),
                session_name=args.session_name,
                phone=args.phone,
                twofa_password=args.twofa_password,
            )
        )
        print(f"Wrote {written} messages to {args.output}")
    elif args.command == "replies":
        # If using --from-csv, try to infer channel from URLs and warn on mismatch
        try:
            if getattr(args, 'from_csv', None):
                import pandas as _pd  # local import to keep startup light
                # Read a small sample of URL column to detect handle
                sample = _pd.read_csv(args.from_csv, usecols=['url'], nrows=20)
                url_handles = [
                    _extract_handle_from_url(str(u)) for u in sample['url'].dropna().tolist() if isinstance(u, (str,))
                ]
                inferred = next((h for h in url_handles if h), None)
                provided = _normalize_handle(channel)
                if inferred and provided and _normalize_handle(inferred) != provided:
                    print(
                        f"[warning] CSV appears to be from @{_normalize_handle(inferred)} but you passed -c @{provided}. "
                        f"Replies may be empty. Consider using -c https://t.me/{inferred}",
                        flush=True,
                    )
        except Exception:
            # Best-effort only; ignore any issues reading/inspecting CSV
            pass
        parent_ids: Set[int]
        if getattr(args, "ids", None):
            parent_ids = {int(x.strip()) for x in args.ids.split(",") if x.strip()}
        else:
            import pandas as pd  # local import
            usecols = ['id']
            if args.min_replies is not None:
                usecols.append('replies')
            df = pd.read_csv(args.from_csv, usecols=usecols)
            if args.min_replies is not None and 'replies' in df.columns:
                df = df[df['replies'].fillna(0).astype(int) >= int(args.min_replies)]
            parent_ids = set(int(x) for x in df['id'].dropna().astype(int).tolist())
        existing_pairs = None
        if args.resume and os.path.exists(args.output):
            # Resume: preload (parent_id, id) pairs already in the output so they are skipped.
            try:
                import csv as _csv
                existing_pairs = set()
                with open(args.output, "r", encoding="utf-8") as _f:
                    reader = _csv.DictReader(_f)
                    for row in reader:
                        try:
                            existing_pairs.add((int(row.get("parent_id")), int(row.get("id"))))
                        except Exception:
                            continue
            except Exception:
                existing_pairs = None
        asyncio.run(
            fetch_replies(
                channel=channel,
                parent_ids=sorted(parent_ids),
                output_csv=args.output,
                append=getattr(args, "append", False),
                session_name=args.session_name,
                phone=args.phone,
                twofa_password=args.twofa_password,
                concurrency=max(1, int(getattr(args, 'concurrency', 5))),
                existing_pairs=existing_pairs,
            )
        )
        print(f"Saved replies to {args.output}")
    elif args.command == "forwards":
        parent_ids: Set[int]
        if getattr(args, "ids", None):
            parent_ids = {int(x.strip()) for x in args.ids.split(",") if x.strip()}
        else:
            import pandas as pd
            df = pd.read_csv(args.from_csv)
            parent_ids = set(int(x) for x in df['id'].dropna().astype(int).tolist())
        asyncio.run(
            fetch_forwards(
                channel=channel,
                parent_ids=parent_ids,
                output_csv=args.output,
                start_date=args.start_date,
                end_date=args.end_date,
                scan_limit=args.scan_limit,
                concurrency=max(1, int(getattr(args, 'concurrency', 5))),
                chunk_size=max(1, int(getattr(args, 'chunk_size', 1000))),
                append=getattr(args, "append", False),
                session_name=args.session_name,
                phone=args.phone,
                twofa_password=args.twofa_password,
            )
        )
        print(f"Saved forwards to {args.output}")


if __name__ == "__main__":
    main()

135
src/train_sentiment.py Normal file
View File

@@ -0,0 +1,135 @@
import argparse
import os
from typing import Optional
import pandas as pd
from datasets import Dataset, ClassLabel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import inspect
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
def build_dataset(df: pd.DataFrame, text_col: str, label_col: str, label_mapping: Optional[dict] = None) -> "tuple[Dataset, dict, dict]":
    """Build a Hugging Face Dataset plus label mappings from a labeled dataframe.

    Args:
        df: Source dataframe.
        text_col: Name of the column holding the text.
        label_col: Name of the column holding labels (strings like neg/neu/pos,
            or integers assumed to be 0..N-1).
        label_mapping: Optional explicit mapping applied to the raw labels first.

    Returns:
        A tuple ``(dataset, label2id, id2label)``.  (The original annotation
        claimed ``-> Dataset`` but the function has always returned this tuple.)

    Raises:
        SystemExit: If no usable labeled rows remain after cleaning.
    """
    d = df[[text_col, label_col]].dropna().copy()
    # Normalize and drop empty labels
    d[label_col] = d[label_col].astype(str).str.strip()
    d = d[d[label_col] != '']
    if d.empty:
        raise SystemExit("No labeled rows found. Please fill the 'label' column in your CSV (e.g., neg/neu/pos or 0/1/2).")
    if label_mapping:
        d[label_col] = d[label_col].map(label_mapping)
    # If labels are strings, factorize them
    if d[label_col].dtype == object:
        d[label_col] = d[label_col].astype('category')
        label2id = {k: int(v) for v, k in enumerate(d[label_col].cat.categories)}
        id2label = {v: k for k, v in label2id.items()}
        d[label_col] = d[label_col].cat.codes
    else:
        # Assume numeric 0..N-1 — TODO confirm callers never pass sparse/negative codes
        classes = sorted(d[label_col].unique().tolist())
        label2id = {str(c): int(c) for c in classes}
        id2label = {int(c): str(c) for c in classes}
    hf = Dataset.from_pandas(d.reset_index(drop=True))
    hf = hf.class_encode_column(label_col)
    # NOTE(review): overwriting features after class_encode_column re-labels the
    # ClassLabel names to our id2label ordering — verify it matches the encoded codes.
    hf.features[label_col] = ClassLabel(num_classes=len(id2label), names=[id2label[i] for i in range(len(id2label))])
    return hf, label2id, id2label
def tokenize_fn(examples, tokenizer, text_col):
    """Tokenize one batch of examples; padding is left to the collator."""
    texts = examples[text_col]
    return tokenizer(texts, truncation=True, padding=False)
def compute_metrics(eval_pred):
    """Compute accuracy plus macro-averaged precision/recall/F1 for a Trainer eval step."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    macro_kwargs = {'average': 'macro', 'zero_division': 0}
    metrics = {'accuracy': accuracy_score(labels, preds)}
    metrics['precision_macro'] = precision_score(labels, preds, **macro_kwargs)
    metrics['recall_macro'] = recall_score(labels, preds, **macro_kwargs)
    metrics['f1_macro'] = f1_score(labels, preds, **macro_kwargs)
    return metrics
def main():
    """CLI entry point: fine-tune a sequence-classification model on a labeled CSV."""
    parser = argparse.ArgumentParser(description='Fine-tune a transformers model for sentiment.')
    parser.add_argument('--train-csv', required=True, help='Path to labeled CSV')
    parser.add_argument('--text-col', default='message', help='Text column name')
    parser.add_argument('--label-col', default='label', help='Label column name (e.g., pos/neu/neg or 2/1/0)')
    parser.add_argument('--model-name', default='distilbert-base-uncased', help='Base model name or path')
    parser.add_argument('--output-dir', default='models/sentiment-distilbert', help='Where to save the fine-tuned model')
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--eval-split', type=float, default=0.1, help='Fraction of data for eval')
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    df = pd.read_csv(args.train_csv)
    ds, label2id, id2label = build_dataset(df, args.text_col, args.label_col)
    if args.eval_split > 0:
        # Stratified split keeps the class balance in both partitions.
        ds = ds.train_test_split(test_size=args.eval_split, seed=42, stratify_by_column=args.label_col)
        train_ds, eval_ds = ds['train'], ds['test']
    else:
        train_ds, eval_ds = ds, None
    num_labels = len(id2label)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=num_labels,
        id2label=id2label,
        label2id={k: int(v) for k, v in label2id.items()},
    )
    tokenized_train = train_ds.map(lambda x: tokenize_fn(x, tokenizer, args.text_col), batched=True)
    tokenized_eval = eval_ds.map(lambda x: tokenize_fn(x, tokenizer, args.text_col), batched=True) if (eval_ds is not None) else None
    # Build TrainingArguments with compatibility across transformers versions
    base_kwargs = {
        'output_dir': args.output_dir,
        'per_device_train_batch_size': args.batch_size,
        'per_device_eval_batch_size': args.batch_size,
        'num_train_epochs': args.epochs,
        'learning_rate': args.lr,
        'fp16': False,
        'logging_steps': 50,
    }
    eval_kwargs = {}
    if tokenized_eval is not None:
        # Set both evaluation_strategy and eval_strategy for compatibility across transformers versions
        eval_kwargs.update({
            'evaluation_strategy': 'epoch',
            'eval_strategy': 'epoch',
            'save_strategy': 'epoch',
            'load_best_model_at_end': True,
            'metric_for_best_model': 'f1_macro',
            'greater_is_better': True,
        })
    # Filter kwargs to only include parameters supported by this transformers version
    sig = inspect.signature(TrainingArguments.__init__)
    allowed = set(sig.parameters.keys())

    def _filter(d: dict) -> dict:
        # Keep only kwargs that this transformers version's TrainingArguments accepts.
        return {k: v for k, v in d.items() if k in allowed}

    training_args = TrainingArguments(**_filter(base_kwargs), **_filter(eval_kwargs))
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_eval,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics if tokenized_eval else None,
    )
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print(f"Model saved to {args.output_dir}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,90 @@
from typing import List, Optional
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
class TransformerSentiment:
    def __init__(self, model_name_or_path: str, device: Optional[str] = None, max_length: int = 256):
        """Load tokenizer and model, then move the model to the best available device.

        Args:
            model_name_or_path: HF hub id or local path of a sequence-classification model.
            device: Explicit torch device string; when None, auto-detects cuda/mps/cpu.
            max_length: Truncation length used when tokenizing inputs.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)
        self.max_length = max_length
        if device is None:
            # Prefer CUDA, then Apple MPS, otherwise CPU.
            if torch.cuda.is_available():
                device = 'cuda'
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                device = 'mps'
            else:
                device = 'cpu'
        self.device = device
        self.model.to(self.device)
        self.model.eval()
        # Expect labels roughly like {0:'neg',1:'neu',2:'pos'} or similar
        self.id2label = self.model.config.id2label if hasattr(self.model.config, 'id2label') else {0:'0',1:'1',2:'2'}
def _compound_from_probs(self, probs: np.ndarray) -> float:
# Map class probabilities to a [-1,1] compound-like score.
# If we have exactly 3 labels and names look like neg/neu/pos (any case), use that mapping.
labels = [self.id2label.get(i, str(i)).lower() for i in range(len(probs))]
try:
neg_idx = labels.index('neg') if 'neg' in labels else labels.index('negative')
except ValueError:
neg_idx = 0
try:
neu_idx = labels.index('neu') if 'neu' in labels else labels.index('neutral')
except ValueError:
neu_idx = 1 if len(probs) > 2 else None
try:
pos_idx = labels.index('pos') if 'pos' in labels else labels.index('positive')
except ValueError:
pos_idx = (len(probs)-1)
p_neg = float(probs[neg_idx]) if neg_idx is not None else 0.0
p_pos = float(probs[pos_idx]) if pos_idx is not None else 0.0
# A simple skew: pos - neg; keep within [-1,1]
comp = max(-1.0, min(1.0, p_pos - p_neg))
return comp
    @torch.no_grad()
    def predict_compound_batch(self, texts: List[str], batch_size: int = 32) -> List[float]:
        """Return one compound-like score in [-1, 1] per input text, batching inference."""
        out: List[float] = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            enc = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt'
            )
            # Move the encoded batch to the model's device before the forward pass.
            enc = {k: v.to(self.device) for k, v in enc.items()}
            logits = self.model(**enc).logits
            probs = torch.softmax(logits, dim=-1).cpu().numpy()
            for row in probs:
                out.append(self._compound_from_probs(row))
        return out
    @torch.no_grad()
    def predict_probs_and_labels(self, texts: List[str], batch_size: int = 32):
        """Return (per-text probability rows, per-text predicted label names), batching inference."""
        probs_all = []
        labels_all: List[str] = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            enc = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt'
            )
            # Move the encoded batch to the model's device before the forward pass.
            enc = {k: v.to(self.device) for k, v in enc.items()}
            logits = self.model(**enc).logits
            probs = torch.softmax(logits, dim=-1).cpu().numpy()
            preds = probs.argmax(axis=-1)
            for j, row in enumerate(probs):
                probs_all.append(row)
                # Fall back to the stringified class index when id2label lacks the key.
                label = self.id2label.get(int(preds[j]), str(int(preds[j])))
                labels_all.append(label)
        return probs_all, labels_all