chore(repo): initialize git with .gitignore, .gitattributes, and project sources
This commit is contained in:
1
src/__init__.py
Normal file
1
src/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# This file is intentionally left blank.
|
||||
1313
src/analyze_csv.py
Normal file
1313
src/analyze_csv.py
Normal file
File diff suppressed because it is too large
Load Diff
50
src/apply_labels.py
Normal file
50
src/apply_labels.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import argparse
|
||||
import os
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def read_csv(path: str) -> pd.DataFrame:
    """Load a CSV into a DataFrame, exiting with a clear message if missing.

    Args:
        path: Filesystem path to the CSV file.

    Returns:
        The parsed DataFrame.

    Raises:
        SystemExit: If no file exists at ``path``.
    """
    if os.path.exists(path):
        return pd.read_csv(path)
    raise SystemExit(f"CSV not found: {path}")
|
||||
|
||||
def main():
    """CLI entry point: merge human/auto sentiment labels back onto posts and replies CSVs."""
    parser = argparse.ArgumentParser(description='Apply labeled sentiments to posts/replies CSVs for analysis plots.')
    parser.add_argument('--labeled-csv', required=True, help='Path to labeled_sentiment.csv (must include id and label columns)')
    parser.add_argument('--posts-csv', required=True, help='Original posts CSV')
    parser.add_argument('--replies-csv', required=True, help='Original replies CSV')
    parser.add_argument('--posts-out', default=None, help='Output posts CSV path (default: <posts> with _with_labels suffix)')
    parser.add_argument('--replies-out', default=None, help='Output replies CSV path (default: <replies> with _with_labels suffix)')
    args = parser.parse_args()

    labeled = read_csv(args.labeled_csv)
    if 'id' not in labeled.columns:
        raise SystemExit('labeled CSV must include an id column to merge on')

    # Normalize the label column name to 'sentiment_label'.
    if 'label' in labeled.columns:
        label_col = 'label'
    elif 'sentiment_label' in labeled.columns:
        label_col = 'sentiment_label'
    else:
        raise SystemExit("labeled CSV must include a 'label' or 'sentiment_label' column")

    keep_cols = ['id', label_col]
    if 'confidence' in labeled.columns:
        keep_cols.append('confidence')
    labeled = labeled[keep_cols].copy().rename(columns={label_col: 'sentiment_label'})

    posts = read_csv(args.posts_csv)
    replies = read_csv(args.replies_csv)
    if 'id' not in posts.columns or 'id' not in replies.columns:
        raise SystemExit('posts/replies CSVs must include id columns')

    # Default output paths: append a _with_labels suffix before the extension.
    posts_out = args.posts_out or os.path.splitext(args.posts_csv)[0] + '_with_labels.csv'
    replies_out = args.replies_out or os.path.splitext(args.replies_csv)[0] + '_with_labels.csv'

    # validate='m:1' guards against duplicate ids on the labeled side.
    posts_merged = posts.merge(labeled, how='left', on='id', validate='m:1')
    replies_merged = replies.merge(labeled, how='left', on='id', validate='m:1')

    posts_merged.to_csv(posts_out, index=False)
    replies_merged.to_csv(replies_out, index=False)
    print(f"Wrote posts with labels -> {posts_out} (rows={len(posts_merged)})")
    print(f"Wrote replies with labels -> {replies_out} (rows={len(replies_merged)})")


if __name__ == '__main__':
    main()
107
src/audit_team_sentiment.py
Normal file
107
src/audit_team_sentiment.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import argparse
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
|
||||
|
||||
|
||||
def parse_tags_column(series: pd.Series) -> pd.Series:
    """Convert a tags column into lists of tag strings.

    Values that are already lists pass through; NaN becomes an empty list.
    Delimited strings are split on ';' (preferred) or ',', dropping empty
    pieces. A single undelimited token is wrapped in a one-element list.
    """
    def _split(value):
        if isinstance(value, list):
            return value
        if pd.isna(value):
            return []
        text = str(value)
        # Semicolon is the augmented-CSV delimiter; comma accepted as fallback.
        for sep in (';', ','):
            if sep in text:
                return [part.strip() for part in text.split(sep) if part.strip()]
        return [text] if text else []

    return series.apply(_split)
|
||||
|
||||
def main():
    """Audit per-team sentiment from an analyze-augmented CSV.

    Reads a tagged posts CSV (must contain 'message', 'sentiment_compound',
    'tags'), aggregates sentiment shares per club tag into
    <out-dir>/team_sentiment_audit.csv, and exports the top-N most positive
    rows for one selected team to <out-dir>/<team>_samples.csv.
    """
    parser = argparse.ArgumentParser(description='Audit sentiment per team tag and export samples for inspection.')
    parser.add_argument('--csv', default='data/premier_league_update_tagged.csv', help='Tagged posts CSV (augmented by analyze)')
    parser.add_argument('--team', default='club_manchester_united', help='Team tag to export samples for (e.g., club_manchester_united)')
    parser.add_argument('--out-dir', default='data', help='Directory to write audit outputs')
    parser.add_argument('--samples', type=int, default=25, help='Number of samples to export for the specified team')
    parser.add_argument('--with-vader', action='store_true', help='Also compute VADER-based sentiment shares as a sanity check')
    args = parser.parse_args()

    if not os.path.exists(args.csv):
        raise SystemExit(f"CSV not found: {args.csv}. Run analyze with --write-augmented-csv first.")

    df = pd.read_csv(args.csv)
    # Fail fast on missing required columns rather than erroring mid-pipeline.
    if 'message' not in df.columns:
        raise SystemExit('CSV missing message column')
    if 'sentiment_compound' not in df.columns:
        raise SystemExit('CSV missing sentiment_compound column')
    if 'tags' not in df.columns:
        raise SystemExit('CSV missing tags column')

    df = df.copy()
    df['tags'] = parse_tags_column(df['tags'])
    # Filter to team tags (prefix club_): one row per (post, tag) after explode.
    e = df.explode('tags')
    e = e[e['tags'].notna() & (e['tags'] != '')]
    e = e[e['tags'].astype(str).str.startswith('club_')]
    e = e.dropna(subset=['sentiment_compound'])
    if e.empty:
        print('No team-tagged rows found.')
        return

    # Shares: +/-0.05 is the conventional VADER neutral band around 0.
    e = e.copy()
    e['is_pos'] = e['sentiment_compound'] > 0.05
    e['is_neg'] = e['sentiment_compound'] < -0.05
    grp = (
        e.groupby('tags')
        .agg(
            n=('sentiment_compound', 'count'),
            mean=('sentiment_compound', 'mean'),
            median=('sentiment_compound', 'median'),
            pos_share=('is_pos', 'mean'),
            neg_share=('is_neg', 'mean'),
        )
        .reset_index()
    )
    # Neutral share is the remainder; clip guards against float rounding below 0.
    grp['neu_share'] = (1 - grp['pos_share'] - grp['neg_share']).clip(lower=0)
    grp = grp.sort_values(['n', 'mean'], ascending=[False, False])

    if args.with_vader:
        # Compute VADER shares on the underlying messages per team
        analyzer = SentimentIntensityAnalyzer()

        def _vader_sentiment_share(sub: pd.DataFrame):
            # Returns per-group pos/neg/neu shares recomputed directly from text.
            if sub.empty:
                return pd.Series({'pos_share_vader': 0.0, 'neg_share_vader': 0.0, 'neu_share_vader': 0.0})
            scores = sub['message'].astype(str).apply(lambda t: analyzer.polarity_scores(t or '')['compound'])
            pos = (scores > 0.05).mean()
            neg = (scores < -0.05).mean()
            neu = max(0.0, 1.0 - pos - neg)
            return pd.Series({'pos_share_vader': pos, 'neg_share_vader': neg, 'neu_share_vader': neu})

        # NOTE(review): groupby().apply() on the full frame triggers a
        # FutureWarning (include_groups) on pandas >= 2.2 — confirm pandas version.
        vader_grp = e.groupby('tags').apply(_vader_sentiment_share).reset_index()
        grp = grp.merge(vader_grp, on='tags', how='left')

    os.makedirs(args.out_dir, exist_ok=True)
    out_summary = os.path.join(args.out_dir, 'team_sentiment_audit.csv')
    grp.to_csv(out_summary, index=False)
    print(f"Wrote summary: {out_summary}")

    # Export samples for selected team
    te = e[e['tags'] == args.team].copy()
    if te.empty:
        print(f"No rows for team tag: {args.team}")
        return
    # Sort by sentiment descending to inspect highly positive claims
    te = te.sort_values('sentiment_compound', ascending=False)
    cols = [c for c in ['id', 'date', 'message', 'sentiment_compound', 'url'] if c in te.columns]
    samples_path = os.path.join(args.out_dir, f"{args.team}_samples.csv")
    te[cols].head(args.samples).to_csv(samples_path, index=False)
    print(f"Wrote samples: {samples_path} ({min(args.samples, len(te))} rows)")


if __name__ == '__main__':
    main()
218
src/auto_label_sentiment.py
Normal file
218
src/auto_label_sentiment.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import argparse
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
|
||||
|
||||
try:
|
||||
# Allow both package and direct script execution
|
||||
from .make_labeling_set import load_messages as _load_messages
|
||||
except Exception:
|
||||
from make_labeling_set import load_messages as _load_messages
|
||||
|
||||
|
||||
def _combine_inputs(posts_csv: Optional[str], replies_csv: Optional[str], text_col: str = 'message', min_length: int = 3) -> pd.DataFrame:
    """Combine posts and/or replies CSVs into one deduplicated message frame.

    Each provided CSV is loaded via ``load_messages``; replies additionally
    keep their parent_id column when present. Rows whose message is shorter
    than ``min_length`` and duplicate messages are dropped.

    Raises:
        SystemExit: When neither CSV path is given.
    """
    frames: List[pd.DataFrame] = []
    if posts_csv:
        frames.append(_load_messages(posts_csv, text_col=text_col))
    if replies_csv:
        # Replies carry a pointer to their parent post when available.
        frames.append(_load_messages(replies_csv, text_col=text_col, extra_cols=['parent_id']))
    if not frames:
        raise SystemExit('No input provided. Use --input-csv or --posts-csv/--replies-csv')

    combined = pd.concat(frames, ignore_index=True)
    combined['message'] = combined['message'].fillna('').astype(str)
    long_enough = combined['message'].str.len() >= min_length
    return combined[long_enough].drop_duplicates(subset=['message']).reset_index(drop=True)
|
||||
|
||||
def _map_label_str_to_int(labels: List[str]) -> List[int]:
|
||||
mapping = {'neg': 0, 'negative': 0, 'neu': 1, 'neutral': 1, 'pos': 2, 'positive': 2}
|
||||
out: List[int] = []
|
||||
for lab in labels:
|
||||
lab_l = (lab or '').lower()
|
||||
if lab_l in mapping:
|
||||
out.append(mapping[lab_l])
|
||||
else:
|
||||
# fallback: try to parse integer
|
||||
try:
|
||||
out.append(int(lab))
|
||||
except Exception:
|
||||
out.append(1) # default to neutral
|
||||
return out
|
||||
|
||||
|
||||
def _vader_label(compound: float, pos_th: float, neg_th: float) -> str:
|
||||
if compound >= pos_th:
|
||||
return 'pos'
|
||||
if compound <= neg_th:
|
||||
return 'neg'
|
||||
return 'neu'
|
||||
|
||||
|
||||
def _auto_label_vader(texts: List[str], pos_th: float, neg_th: float, min_margin: float) -> Tuple[List[str], List[float]]:
    """Label texts with VADER and attach a heuristic confidence per text.

    Confidence is the distance from the neutral-band edge (or, for neutral
    labels, closeness to 0), scaled by ``min_margin`` and clipped to [0, 1].
    """
    analyzer = SentimentIntensityAnalyzer()
    labels: List[str] = []
    distances: List[float] = []
    for text in texts:
        compound = float(analyzer.polarity_scores(text or '').get('compound', 0.0))
        label = _vader_label(compound, pos_th, neg_th)
        if label == 'pos':
            distance = compound - pos_th
        elif label == 'neg':
            distance = abs(compound - neg_th)
        else:
            # Closer to 0 means more confidently neutral.
            distance = pos_th - abs(compound)
        labels.append(label)
        distances.append(max(0.0, distance))
    # Roughly normalize to [0,1] using min_margin as the scale, then clip.
    scale = max(1e-6, min_margin)
    return labels, [min(1.0, d / scale) for d in distances]
|
||||
|
||||
def _auto_label_transformers(texts: List[str], model_name_or_path: str, batch_size: int, min_prob: float, min_margin: float) -> Tuple[List[str], List[float]]:
    """Label texts with a 3-class transformers sentiment model.

    Returns (labels, confidences). Confidence combines two criteria —
    top-class probability above ``min_prob`` and a top1-top2 margin above
    ``min_margin`` — taking the weaker of the two normalized scores.
    Model label names are canonicalized to 'neg'/'neu'/'pos' when possible.
    """
    try:
        # Support both package-relative and direct-script execution.
        from .transformer_sentiment import TransformerSentiment
    except Exception:
        from transformer_sentiment import TransformerSentiment

    clf = TransformerSentiment(model_name_or_path)
    probs_all, labels_all = clf.predict_probs_and_labels(texts, batch_size=batch_size)
    confs: List[float] = []
    for row in probs_all:
        row = np.array(row, dtype=float)
        if row.size == 0:
            # No probabilities returned for this text -> zero confidence.
            confs.append(0.0)
            continue
        # Two largest probabilities (padded with 0.0 for degenerate 1-class rows).
        top2 = np.sort(row)[-2:] if row.size >= 2 else np.array([0.0, row.max()])
        max_p = float(row.max())
        margin = float(top2[-1] - top2[-2]) if row.size >= 2 else max_p
        # Confidence must satisfy both conditions
        conf = min(max(0.0, (max_p - min_prob) / max(1e-6, 1 - min_prob)), max(0.0, margin / max(1e-6, min_margin)))
        confs.append(conf)
    # Map arbitrary id2label names to canonical 'neg/neu/pos' when obvious; else keep as-is
    canonical = []
    for lab in labels_all:
        ll = (lab or '').lower()
        if 'neg' in ll:
            canonical.append('neg')
        elif 'neu' in ll or 'neutral' in ll:
            canonical.append('neu')
        elif 'pos' in ll or 'positive' in ll:
            canonical.append('pos')
        else:
            canonical.append(lab)
    return canonical, confs
|
||||
|
||||
def main():
    """Automatically label sentiment for a corpus without manual annotation.

    Loads one CSV (--input-csv) or combined posts/replies CSVs, predicts a
    label and confidence per message with the chosen backend (VADER,
    transformers, or a local GPT via Ollama), optionally keeps only
    confident rows, and writes the labeled CSV to --output.
    """
    parser = argparse.ArgumentParser(description='Automatically label sentiment without manual annotation.')
    src = parser.add_mutually_exclusive_group(required=True)
    src.add_argument('--input-csv', help='Single CSV containing a text column (default: message)')
    src.add_argument('--posts-csv', help='Posts CSV to include')
    parser.add_argument('--replies-csv', help='Replies CSV to include (combined with posts if provided)')
    parser.add_argument('--text-col', default='message', help='Text column name in input CSV(s)')
    parser.add_argument('-o', '--output', default='data/labeled_sentiment.csv', help='Output labeled CSV path')
    parser.add_argument('--limit', type=int, default=None, help='Optional cap on number of rows')
    parser.add_argument('--min-length', type=int, default=3, help='Minimum text length to consider')

    parser.add_argument('--backend', choices=['vader', 'transformers', 'gpt'], default='vader', help='Labeling backend: vader, transformers, or gpt (local via Ollama)')
    # VADER knobs
    parser.add_argument('--vader-pos', type=float, default=0.05, help='VADER positive threshold (compound >=)')
    parser.add_argument('--vader-neg', type=float, default=-0.05, help='VADER negative threshold (compound <=)')
    parser.add_argument('--vader-margin', type=float, default=0.2, help='Confidence scaling for VADER distance')
    # Transformers knobs
    parser.add_argument('--transformers-model', default='cardiffnlp/twitter-roberta-base-sentiment-latest', help='HF model for 3-class sentiment')
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--min-prob', type=float, default=0.6, help='Min top class probability to accept')
    parser.add_argument('--min-margin', type=float, default=0.2, help='Min prob gap between top-1 and top-2 to accept')

    # GPT knobs
    parser.add_argument('--gpt-model', default='llama3', help='Local GPT model name (Ollama)')
    parser.add_argument('--gpt-base-url', default='http://localhost:11434', help='Base URL for local GPT server (Ollama)')
    parser.add_argument('--gpt-batch-size', type=int, default=8)

    parser.add_argument('--label-format', choices=['str', 'int'], default='str', help="Output labels as strings ('neg/neu/pos') or integers (0/1/2)")
    parser.add_argument('--only-confident', action='store_true', help='Drop rows that do not meet confidence thresholds')

    args = parser.parse_args()

    # Load inputs
    if args.input_csv:
        if not os.path.exists(args.input_csv):
            raise SystemExit(f"Input CSV not found: {args.input_csv}")
        df = pd.read_csv(args.input_csv)
        if args.text_col not in df.columns:
            raise SystemExit(f"Text column '{args.text_col}' not in {args.input_csv}")
        df = df.copy()
        df['message'] = df[args.text_col].astype(str)
        # Keep only the identifying columns that actually exist.
        base_cols = [c for c in ['id', 'date', 'message', 'url'] if c in df.columns]
        df = df[base_cols if base_cols else ['message']]
        df = df[df['message'].str.len() >= args.min_length]
        df = df.drop_duplicates(subset=['message']).reset_index(drop=True)
    else:
        df = _combine_inputs(args.posts_csv, args.replies_csv, text_col=args.text_col, min_length=args.min_length)

    if args.limit and len(df) > args.limit:
        df = df.head(args.limit)

    texts = df['message'].astype(str).tolist()

    # Predict labels + confidence
    if args.backend == 'vader':
        labels, conf = _auto_label_vader(texts, pos_th=args.vader_pos, neg_th=args.vader_neg, min_margin=args.vader_margin)
        # For VADER, define acceptance: confident if outside neutral band by at least margin, or inside band with closeness to 0 below threshold
        accept = []
        # NOTE(review): re-scoring here duplicates the VADER pass inside
        # _auto_label_vader; consider returning compounds from it instead.
        analyzer = SentimentIntensityAnalyzer()
        for t in texts:
            comp = analyzer.polarity_scores(t or '').get('compound')
            if comp is None:
                accept.append(False)
                continue
            comp = float(comp)
            if comp >= args.vader_pos + args.vader_margin or comp <= args.vader_neg - args.vader_margin:
                accept.append(True)
            else:
                # inside or near band -> consider less confident
                accept.append(False)
    elif args.backend == 'transformers':
        labels, conf = _auto_label_transformers(texts, args.transformers_model, args.batch_size, args.min_prob, args.min_margin)
        # NOTE(review): (c >= 1.0) or (c >= 0.5) reduces to c >= 0.5.
        accept = [((c >= 1.0)) or ((c >= 0.5)) for c in conf]  # normalize conf ~[0,1]; accept medium-high confidence
    else:
        # GPT backend via Ollama: expect label+confidence
        try:
            from .gpt_sentiment import GPTSentiment
        except Exception:
            from gpt_sentiment import GPTSentiment
        clf = GPTSentiment(base_url=args.gpt_base_url, model=args.gpt_model)
        labels, conf = clf.predict_label_conf_batch(texts, batch_size=args.gpt_batch_size)
        # Accept medium-high confidence; simple threshold like transformers path
        accept = [c >= 0.5 for c in conf]

    out = df.copy()
    # Label goes right after the first column for easy eyeballing.
    out.insert(1, 'label', labels)
    out['confidence'] = conf

    if args.only_confident:
        out = out[np.array(accept, dtype=bool)]
        out = out.reset_index(drop=True)

    if args.label_format == 'int':
        out['label'] = _map_label_str_to_int(out['label'].astype(str).tolist())

    os.makedirs(os.path.dirname(args.output) or '.', exist_ok=True)
    out.to_csv(args.output, index=False)
    kept = len(out)
    print(f"Wrote {kept} labeled rows to {args.output} using backend={args.backend}")
    if args.only_confident:
        print("Note: only confident predictions were kept. You can remove --only-confident to include all rows.")


if __name__ == '__main__':
    main()
48
src/eval_sentiment.py
Normal file
48
src/eval_sentiment.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import argparse
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report
|
||||
|
||||
try:
|
||||
from .transformer_sentiment import TransformerSentiment
|
||||
except ImportError:
|
||||
# Allow running as a script via -m src.eval_sentiment
|
||||
from transformer_sentiment import TransformerSentiment
|
||||
|
||||
|
||||
def main():
    """Evaluate a transformers sentiment model against a labeled CSV.

    Predicts labels for every (message, label) row, then prints accuracy,
    macro precision/recall/F1, and the full sklearn classification report.
    True and predicted labels are compared as raw strings.
    """
    parser = argparse.ArgumentParser(description='Evaluate a fine-tuned transformers sentiment model on a labeled CSV')
    parser.add_argument('--csv', required=True, help='Labeled CSV path with message and label columns')
    parser.add_argument('--text-col', default='message')
    parser.add_argument('--label-col', default='label')
    parser.add_argument('--model', required=True, help='Model name or local path')
    parser.add_argument('--batch-size', type=int, default=64)
    args = parser.parse_args()

    df = pd.read_csv(args.csv)
    # Drop rows missing either text or label before scoring.
    df = df[[args.text_col, args.label_col]].dropna().copy()
    texts = df[args.text_col].astype(str).tolist()
    true_labels = df[args.label_col].astype(str).tolist()

    clf = TransformerSentiment(args.model)
    _, pred_labels = clf.predict_probs_and_labels(texts, batch_size=args.batch_size)

    y_true = np.array(true_labels)
    y_pred = np.array(pred_labels)

    # If labels differ from model id2label names, normalize to strings for comparison
    # NOTE(review): only string casting is done; 'positive' vs 'pos' naming
    # mismatches would score as wrong — confirm label vocabularies agree.
    acc = accuracy_score(y_true, y_pred)
    f1_macro = f1_score(y_true, y_pred, average='macro', zero_division=0)
    prec_macro = precision_score(y_true, y_pred, average='macro', zero_division=0)
    rec_macro = recall_score(y_true, y_pred, average='macro', zero_division=0)

    print('Accuracy:', f"{acc:.4f}")
    print('F1 (macro):', f"{f1_macro:.4f}")
    print('Precision (macro):', f"{prec_macro:.4f}")
    print('Recall (macro):', f"{rec_macro:.4f}")
    print('\nClassification report:')
    print(classification_report(y_true, y_pred, zero_division=0))


if __name__ == '__main__':
    main()
||||
131
src/fetch_schedule.py
Normal file
131
src/fetch_schedule.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
|
||||
API_BASE = "https://api.football-data.org/v4"
|
||||
COMPETITION_CODE = "PL" # Premier League
|
||||
|
||||
|
||||
def iso_date(d: str) -> str:
    """argparse type: validate an ISO date string and return just the date part.

    Raises:
        argparse.ArgumentTypeError: When ``d`` is not a valid ISO date/datetime.
    """
    try:
        parsed = datetime.fromisoformat(d)
    except Exception as e:
        raise argparse.ArgumentTypeError(f"Invalid date: {d}. Use YYYY-MM-DD") from e
    return parsed.date().isoformat()
|
||||
|
||||
def fetch_matches(start_date: str, end_date: str, token: str) -> Dict[str, Any]:
    """Query the football-data.org v4 Premier League matches endpoint.

    Args:
        start_date: ISO date, inclusive lower bound (dateFrom).
        end_date: ISO date, inclusive upper bound (dateTo).
        token: API token sent via the X-Auth-Token header.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.HTTPError: On a non-2xx HTTP status.
    """
    response = requests.get(
        f"{API_BASE}/competitions/{COMPETITION_CODE}/matches",
        headers={"X-Auth-Token": token},
        params={"dateFrom": start_date, "dateTo": end_date},
        timeout=30,
    )
    response.raise_for_status()
    return response.json()
|
||||
|
||||
def normalize_match(m: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten one football-data.org match object into a flat dict for CSV/JSON.

    Args:
        m: Raw match dict as returned by the v4 API (may have missing keys).

    Returns:
        A dict with scalar values only: id/status/matchday, ISO kickoff time,
        team names, full-time scores, referee names joined with ", ",
        competition metadata, and a website link (None when there is no id).
    """
    utc_date = m.get("utcDate")
    # Convert the API's "...Z" timestamp to an explicit-offset ISO string;
    # datetime.fromisoformat() before Python 3.11 rejects a trailing "Z".
    kick_iso = None
    if utc_date:
        try:
            kick_iso = datetime.fromisoformat(utc_date.replace("Z", "+00:00")).isoformat()
        except Exception:
            kick_iso = utc_date  # keep the raw value rather than dropping the row
    score = m.get("score", {})
    full_time = score.get("fullTime", {})
    match_id = m.get("id")

    return {
        "id": match_id,
        "status": m.get("status"),
        "matchday": m.get("matchday"),
        "utcDate": kick_iso,
        # Teams may be null in early fixtures; "or {}" keeps .get() safe.
        "homeTeam": (m.get("homeTeam") or {}).get("name"),
        "awayTeam": (m.get("awayTeam") or {}).get("name"),
        "homeScore": full_time.get("home"),
        "awayScore": full_time.get("away"),
        "referees": ", ".join([r.get("name", "") for r in m.get("referees", []) if r.get("name")]),
        "venue": m.get("area", {}).get("name"),
        "competition": (m.get("competition") or {}).get("name"),
        "stage": m.get("stage"),
        "group": m.get("group"),
        # Explicit conditional instead of the error-prone `x and a or b` idiom.
        "link": f"https://www.football-data.org/match/{match_id}" if match_id else None,
    }
|
||||
|
||||
def save_csv(matches: List[Dict[str, Any]], out_path: str) -> None:
    """Write matches to CSV; an empty list produces a header-only file.

    Column order comes from the first match's keys; with no matches, the
    canonical normalized-match field list is used for the header.
    """
    if matches:
        fields = list(matches[0].keys())
        rows = matches
    else:
        # Keep the schema stable even when the date range had no fixtures.
        fields = [
            "id",
            "status",
            "matchday",
            "utcDate",
            "homeTeam",
            "awayTeam",
            "homeScore",
            "awayScore",
            "referees",
            "venue",
            "competition",
            "stage",
            "group",
            "link",
        ]
        rows = []
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fields)
        writer.writeheader()
        writer.writerows(rows)
|
||||
|
||||
def save_json(matches: List[Dict[str, Any]], out_path: str) -> None:
    """Dump matches to a pretty-printed UTF-8 JSON file at ``out_path``."""
    import json

    payload = json.dumps(matches, ensure_ascii=False, indent=2)
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(payload)
|
||||
|
||||
def main():
    """Fetch Premier League fixtures for a date range and save them.

    Requires FOOTBALL_DATA_API_TOKEN in the environment (or .env). The
    output format is chosen by the --output extension (.csv or .json).
    """
    parser = argparse.ArgumentParser(description="Fetch Premier League fixtures in a date range and save to CSV/JSON")
    parser.add_argument("--start-date", required=True, type=iso_date, help="YYYY-MM-DD (inclusive)")
    parser.add_argument("--end-date", required=True, type=iso_date, help="YYYY-MM-DD (inclusive)")
    parser.add_argument("-o", "--output", required=True, help="Output file path (.csv or .json)")
    args = parser.parse_args()

    # Pull credentials from .env into the environment before reading them.
    load_dotenv()
    token = os.getenv("FOOTBALL_DATA_API_TOKEN")
    if not token:
        raise SystemExit("Missing FOOTBALL_DATA_API_TOKEN in environment (.env)")

    data = fetch_matches(args.start_date, args.end_date, token)
    matches_raw = data.get("matches", [])
    matches = [normalize_match(m) for m in matches_raw]

    # Ensure the output directory exists ('.' when the path has no directory).
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)

    ext = os.path.splitext(args.output)[1].lower()
    if ext == ".csv":
        save_csv(matches, args.output)
    elif ext == ".json":
        save_json(matches, args.output)
    else:
        raise SystemExit("Output must end with .csv or .json")

    print(f"Saved {len(matches)} matches to {args.output}")


if __name__ == "__main__":
    main()
93
src/gpt_sentiment.py
Normal file
93
src/gpt_sentiment.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import json
|
||||
from typing import List, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class GPTSentiment:
    """Minimal client for a local GPT model served by Ollama.

    The model is expected to answer with a strict JSON object of the form
    {"label": "neg|neu|pos", "confidence": 0.0..1.0}.

    Requests go to POST {base_url}/api/generate with the payload
    {"model": <model>, "prompt": <prompt>, "stream": false, "format": "json"}.
    """

    def __init__(self, base_url: str = "http://localhost:11434", model: str = "llama3", timeout: int = 30):
        # Store a trailing-slash-free base URL so endpoint joins are predictable.
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.timeout = timeout

    def _build_prompt(self, text: str) -> str:
        """Build a terse, deterministic instruction that demands strict JSON."""
        return (
            "You are a strict JSON generator for sentiment analysis. "
            "Classify the INPUT text as one of: neg, neu, pos. "
            "Return ONLY a JSON object with keys 'label' and 'confidence' (0..1). "
            "No markdown, no prose.\n\n"
            f"INPUT: {text}"
        )

    def _call(self, prompt: str) -> dict:
        """Send one prompt to Ollama and parse the model's JSON reply."""
        response = requests.post(
            f"{self.base_url}/api/generate",
            json={
                "model": self.model,
                "prompt": prompt,
                "stream": False,
                "format": "json",
            },
            timeout=self.timeout,
        )
        response.raise_for_status()
        # Ollama returns the model's text under the 'response' key.
        raw = response.json().get("response", "").strip()
        try:
            return json.loads(raw)
        except Exception:
            # Recover simple cases where the model wrapped the JSON in codefences.
            return json.loads(raw.strip().removeprefix("```").removesuffix("```"))

    @staticmethod
    def _canonical_label(s: str) -> str:
        """Normalize a free-form label to 'neg'/'neu'/'pos' (default 'neu')."""
        label = (s or "").strip().lower()
        for needle, canon in (("neg", "neg"), ("neu", "neu"), ("pos", "pos")):
            if needle in label:
                return canon
        return label or "neu"

    @staticmethod
    def _compound_from_label_conf(label: str, confidence: float) -> float:
        """Map (label, confidence) to a VADER-style compound in [-1, 1]."""
        canon = GPTSentiment._canonical_label(label)
        clipped = max(0.0, min(1.0, float(confidence or 0.0)))
        if canon == "pos":
            return clipped
        return -clipped if canon == "neg" else 0.0

    def predict_label_conf_batch(self, texts: List[str], batch_size: int = 8) -> Tuple[List[str], List[float]]:
        """Classify each text; failures fall back to ('neu', 0.0)."""
        labels: List[str] = []
        confs: List[float] = []
        for start in range(0, len(texts), batch_size):
            for text in texts[start:start + batch_size]:
                try:
                    parsed = self._call(self._build_prompt(text))
                    label = self._canonical_label(parsed.get("label", ""))
                    conf = float(parsed.get("confidence", 0.0))
                except Exception:
                    label, conf = "neu", 0.0
                labels.append(label)
                confs.append(conf)
        return labels, confs

    def predict_compound_batch(self, texts: List[str], batch_size: int = 8) -> List[float]:
        """Classify each text and return signed compound scores."""
        labels, confs = self.predict_label_conf_batch(texts, batch_size=batch_size)
        return [self._compound_from_label_conf(label, conf) for label, conf in zip(labels, confs)]
65
src/make_labeling_set.py
Normal file
65
src/make_labeling_set.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import argparse
|
||||
import os
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def load_messages(csv_path: str, text_col: str = 'message', extra_cols=None) -> pd.DataFrame:
    """Load id/text/date (plus optional extra) columns from a CSV.

    Returns an empty DataFrame when the file is missing or lacks
    ``text_col``. The text column is renamed to 'message' in the result,
    and only columns actually present in the CSV are kept.
    """
    if not os.path.exists(csv_path):
        return pd.DataFrame()
    df = pd.read_csv(csv_path)
    if text_col not in df.columns:
        return pd.DataFrame()
    wanted = ['id', text_col, 'date'] + [c for c in (extra_cols or []) if c in df.columns]
    present = [c for c in wanted if c in df.columns]
    return df[present].copy().rename(columns={text_col: 'message'})
|
||||
|
||||
def main():
    """Build a CSV of messages ready for manual sentiment labeling.

    Combines posts and/or replies CSVs, filters short and duplicate
    messages, optionally shuffles (deterministic seed) and caps the sample,
    inserts an empty 'label' column, and writes the result to --output.
    """
    parser = argparse.ArgumentParser(description='Create a labeling CSV from posts and/or replies.')
    parser.add_argument('--posts-csv', required=False, help='Posts CSV path (e.g., data/..._update.csv)')
    parser.add_argument('--replies-csv', required=False, help='Replies CSV path')
    parser.add_argument('-o', '--output', default='data/labeled_sentiment.csv', help='Output CSV for labeling')
    parser.add_argument('--sample-size', type=int, default=1000, help='Total rows to include (after combining)')
    parser.add_argument('--min-length', type=int, default=3, help='Minimum message length to include')
    parser.add_argument('--shuffle', action='store_true', help='Shuffle before sampling (default true)')
    parser.add_argument('--no-shuffle', dest='shuffle', action='store_false')
    parser.set_defaults(shuffle=True)
    args = parser.parse_args()

    frames = []
    if args.posts_csv:
        frames.append(load_messages(args.posts_csv))
    if args.replies_csv:
        # For replies, include parent_id if present
        r = load_messages(args.replies_csv, extra_cols=['parent_id'])
        frames.append(r)
    if not frames:
        raise SystemExit('No input CSVs provided. Use --posts-csv and/or --replies-csv.')

    df = pd.concat(frames, ignore_index=True)
    # Basic filtering: non-empty text, min length, drop duplicates by message text
    df['message'] = df['message'].fillna('').astype(str)
    df = df[df['message'].str.len() >= args.min_length]
    df = df.drop_duplicates(subset=['message']).reset_index(drop=True)

    if args.shuffle:
        # Fixed seed keeps the sample reproducible across runs.
        df = df.sample(frac=1.0, random_state=42).reset_index(drop=True)
    if args.sample_size and len(df) > args.sample_size:
        df = df.head(args.sample_size)

    # Add blank label column for human annotation
    df.insert(1, 'label', '')

    os.makedirs(os.path.dirname(args.output) or '.', exist_ok=True)
    df.to_csv(args.output, index=False)
    print(f"Wrote labeling CSV with {len(df)} rows to {args.output}")


if __name__ == '__main__':
    main()
137
src/plot_labeled.py
Normal file
137
src/plot_labeled.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import argparse
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def safe_read(path: str) -> pd.DataFrame:
    """Load a labeled-sentiment CSV and normalize its optional columns.

    Exits via SystemExit when the file is missing or has no 'label' column.
    When present, 'message' is coerced to str (NaN -> ''), 'confidence' to
    numeric (unparseable -> NaN), and 'date' to datetime (unparseable -> NaT).
    """
    if not os.path.exists(path):
        raise SystemExit(f"Input labeled CSV not found: {path}")
    frame = pd.read_csv(path)
    if 'label' not in frame.columns:
        raise SystemExit("Expected a 'label' column in the labeled CSV")
    # Per-column normalizers, applied only when the column exists.
    normalizers = {
        'message': lambda s: s.fillna('').astype(str),
        'confidence': lambda s: pd.to_numeric(s, errors='coerce'),
        'date': lambda s: pd.to_datetime(s, errors='coerce'),
    }
    for col, fix in normalizers.items():
        if col in frame.columns:
            frame[col] = fix(frame[col])
    return frame
|
||||
|
||||
|
||||
def ensure_out_dir(out_dir: str) -> str:
    """Idempotently create the plot output directory and hand it back."""
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir, exist_ok=True)
    return out_dir
|
||||
|
||||
|
||||
def plot_all(df: pd.DataFrame, out_dir: str) -> None:
    """Render the diagnostic plots for a labeled-sentiment DataFrame as PNGs.

    Saves into *out_dir* (created if missing). Each plot is wrapped in its own
    try/except so one failure (e.g. a missing column) does not abort the rest;
    skipped plots are reported on stdout.

    Columns used: 'label' (always); 'confidence', 'message', and 'date' each
    enable their corresponding plot when present.
    """
    # Imported lazily so the module can be imported without plotting deps installed.
    import matplotlib.pyplot as plt
    import seaborn as sns
    sns.set_style('whitegrid')

    out_dir = ensure_out_dir(out_dir)

    # 1) Class distribution: bar chart of neg/neu/pos counts in fixed order/colors.
    try:
        plt.figure(figsize=(6,4))
        # reindex guarantees all three classes appear even when a class has zero rows.
        # NOTE(review): pandas .plot() may draw on axes it creates itself rather
        # than the figure opened above, making figsize ineffective — confirm.
        ax = (df['label'].astype(str).str.lower().value_counts()
            .reindex(['neg','neu','pos'])
            .fillna(0)
            .rename_axis('label').reset_index(name='count')
            .set_index('label')
            .plot(kind='bar', legend=False, color=['#d62728','#aaaaaa','#2ca02c']))
        plt.title('Labeled class distribution')
        plt.ylabel('Count')
        plt.tight_layout()
        path = os.path.join(out_dir, 'labeled_class_distribution.png')
        plt.savefig(path, dpi=150)
        plt.close()
        print(f"[plots] Saved {path}")
    except Exception as e:
        print(f"[plots] Skipped class distribution: {e}")

    # 2) Confidence histogram (overall) — only when a non-empty confidence column exists.
    if 'confidence' in df.columns and df['confidence'].notna().any():
        try:
            plt.figure(figsize=(6,4))
            sns.histplot(df['confidence'].dropna(), bins=30, color='#1f77b4')
            plt.title('Confidence distribution (overall)')
            plt.xlabel('Confidence'); plt.ylabel('Frequency')
            plt.tight_layout()
            path = os.path.join(out_dir, 'labeled_confidence_hist.png')
            plt.savefig(path, dpi=150); plt.close()
            print(f"[plots] Saved {path}")
        except Exception as e:
            print(f"[plots] Skipped confidence histogram: {e}")

    # 3) Confidence by label (boxplot). If 'confidence' is absent the column
    # selection raises and the plot is skipped via the except below.
    try:
        plt.figure(figsize=(6,4))
        t = df[['label','confidence']].dropna()
        t['label'] = t['label'].astype(str).str.lower()
        order = ['neg','neu','pos']
        sns.boxplot(data=t, x='label', y='confidence', order=order, palette=['#d62728','#aaaaaa','#2ca02c'])
        plt.title('Confidence by label')
        plt.xlabel('Label'); plt.ylabel('Confidence')
        plt.tight_layout()
        path = os.path.join(out_dir, 'labeled_confidence_by_label.png')
        plt.savefig(path, dpi=150); plt.close()
        print(f"[plots] Saved {path}")
    except Exception as e:
        print(f"[plots] Skipped confidence by label: {e}")

    # 4) Message length (characters) by label.
    if 'message' in df.columns:
        try:
            t = df[['label','message']].copy()
            t['label'] = t['label'].astype(str).str.lower()
            t['len'] = t['message'].astype(str).str.len()
            plt.figure(figsize=(6,4))
            sns.boxplot(data=t, x='label', y='len', order=['neg','neu','pos'], palette=['#d62728','#aaaaaa','#2ca02c'])
            plt.title('Message length by label')
            plt.xlabel('Label'); plt.ylabel('Length (chars)')
            plt.tight_layout()
            path = os.path.join(out_dir, 'labeled_length_by_label.png')
            plt.savefig(path, dpi=150); plt.close()
            print(f"[plots] Saved {path}")
        except Exception as e:
            print(f"[plots] Skipped length by label: {e}")

    # 5) Daily counts per label, stacked (only when a usable date column exists).
    if 'date' in df.columns and df['date'].notna().any():
        try:
            t = df[['date','label']].dropna().copy()
            t['day'] = pd.to_datetime(t['date'], errors='coerce').dt.date
            t['label'] = t['label'].astype(str).str.lower()
            # Count rows per (day, label); 'date' is only used as the counted value.
            pv = t.pivot_table(index='day', columns='label', values='date', aggfunc='count').fillna(0)
            # ensure consistent column order
            for c in ['neg','neu','pos']:
                if c not in pv.columns:
                    pv[c] = 0
            pv = pv[['neg','neu','pos']]
            # NOTE(review): plt is already imported at the top of this function;
            # this re-import is redundant but harmless.
            import matplotlib.pyplot as plt
            plt.figure(figsize=(10,4))
            pv.plot(kind='bar', stacked=True, color=['#d62728','#aaaaaa','#2ca02c'])
            plt.title('Daily labeled counts (stacked)')
            plt.xlabel('Day'); plt.ylabel('Count')
            plt.tight_layout()
            path = os.path.join(out_dir, 'labeled_daily_counts.png')
            plt.savefig(path, dpi=150); plt.close()
            print(f"[plots] Saved {path}")
        except Exception as e:
            print(f"[plots] Skipped daily counts: {e}")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: read the labeled CSV and emit every plot."""
    cli = argparse.ArgumentParser(description='Plot graphs from labeled sentiment data.')
    cli.add_argument('-i', '--input', default='data/labeled_sentiment.csv', help='Path to labeled CSV')
    cli.add_argument('-o', '--out-dir', default='data', help='Output directory for plots')
    opts = cli.parse_args()

    plot_all(safe_read(opts.input), opts.out_dir)
|
||||
|
||||
|
||||
# Script entry point: generate the plots when run directly.
if __name__ == '__main__':
    main()
|
||||
749
src/telegram_scraper.py
Normal file
749
src/telegram_scraper.py
Normal file
@@ -0,0 +1,749 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from typing import AsyncIterator, Iterable, Optional, Sequence, Set, List, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from telethon import TelegramClient
|
||||
from telethon.errors import SessionPasswordNeededError
|
||||
from telethon.errors.rpcerrorlist import MsgIdInvalidError, FloodWaitError
|
||||
from telethon.tl.functions.messages import GetDiscussionMessageRequest
|
||||
from telethon.tl.custom.message import Message
|
||||
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
|
||||
|
||||
|
||||
@dataclass
class ScrapedMessage:
    """Flat record for one scraped Telegram message, ready for CSV/JSONL output.

    Field values are extracted from a Telethon message by message_to_record.
    """

    id: int  # message id within the channel
    date: Optional[str]  # ISO format
    message: Optional[str]  # text body; may be None for media-only posts
    sender_id: Optional[int]  # numeric sender id when available
    views: Optional[int]  # view counter when exposed by the API
    forwards: Optional[int]  # forward counter when exposed by the API
    replies: Optional[int]  # reply/comment count when exposed by the API
    url: Optional[str]  # public t.me permalink; None when no channel username given
|
||||
|
||||
|
||||
def to_iso(dt: datetime) -> str:
    """Strip any timezone info and return the naive ISO-8601 timestamp."""
    naive = dt.replace(tzinfo=None)
    return naive.isoformat()
|
||||
|
||||
|
||||
async def iter_messages(
    client: TelegramClient,
    entity: str,
    limit: Optional[int] = None,
    offset_date: Optional[datetime] = None,
) -> AsyncIterator[Message]:
    """Thin async wrapper over Telethon's history iterator.

    Forwards *limit* and *offset_date* straight through to
    client.iter_messages and yields each message unchanged.
    """
    stream = client.iter_messages(entity, limit=limit, offset_date=offset_date)
    async for item in stream:
        yield item
|
||||
|
||||
|
||||
def message_to_record(msg: Message, channel_username: str) -> ScrapedMessage:
    """Convert a Telethon message into a flat ScrapedMessage record.

    *channel_username* (without '@') is used to build the public permalink;
    pass an empty string to omit the URL.
    """
    iso_date = to_iso(msg.date) if msg.date else None
    # sender_id may be a wrapper object exposing .value, or already a plain int.
    raw_sender = getattr(msg.sender_id, 'value', msg.sender_id) if hasattr(msg, 'sender_id') else None
    reply_info = getattr(msg, 'replies', None)
    permalink = f"https://t.me/{channel_username}/{msg.id}" if channel_username else None
    return ScrapedMessage(
        id=msg.id,
        date=iso_date,
        message=msg.message,
        sender_id=raw_sender,
        views=getattr(msg, 'views', None),
        forwards=getattr(msg, 'forwards', None),
        replies=(reply_info.replies if reply_info else None),
        url=permalink,
    )
|
||||
|
||||
|
||||
async def ensure_login(client: TelegramClient, phone: Optional[str] = None, twofa_password: Optional[str] = None):
    """Connect the Telethon client and complete login interactively if needed.

    If the session is not yet authorized, prompts on stdin for the phone
    number (unless *phone* is given), then for the login code, and — when
    two-step verification is enabled — for the account password (unless
    *twofa_password* is given).
    """
    # Connect and log in, prompting interactively if needed
    await client.connect()
    if not await client.is_user_authorized():
        if not phone:
            phone = input("Enter your phone number (with country code): ")
        # Telegram sends a one-time login code to the account.
        await client.send_code_request(phone)
        code = input("Enter the login code you received: ")
        try:
            await client.sign_in(phone=phone, code=code)
        except SessionPasswordNeededError:
            # Account has 2FA enabled; a password is required in addition to the code.
            if twofa_password is None:
                twofa_password = input("Two-step verification enabled. Enter your password: ")
            await client.sign_in(password=twofa_password)
|
||||
|
||||
|
||||
async def scrape_channel(
    channel: str,
    output: str,
    limit: Optional[int] = None,
    offset_date: Optional[str] = None,  # deprecated in favor of start_date
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    append: bool = False,
    session_name: str = "telegram",
    phone: Optional[str] = None,
    twofa_password: Optional[str] = None,
):
    """Scrape a channel's history into a .jsonl/.ndjson or .csv file.

    Iterates newest-first; once a message falls before *start_date* the scan
    stops early. *limit* caps rows written AFTER date filtering. Returns the
    number of records written.

    Raises RuntimeError for missing/invalid API credentials and ValueError
    for an unsupported output extension.

    Fix vs. previous version: the output directory is now created up front
    (consistent with fetch_replies/fetch_forwards); previously open() failed
    when the directory did not exist.
    """
    load_dotenv()
    api_id = os.getenv("TELEGRAM_API_ID")
    api_hash = os.getenv("TELEGRAM_API_HASH")
    session_name = os.getenv("TELEGRAM_SESSION_NAME", session_name)

    if not api_id or not api_hash:
        raise RuntimeError("Missing TELEGRAM_API_ID/TELEGRAM_API_HASH in environment. See .env.example")

    # Some providers store api_id as string; Telethon expects int
    try:
        api_id_int = int(api_id)
    except Exception as e:
        raise RuntimeError("TELEGRAM_API_ID must be an integer") from e

    client = TelegramClient(session_name, api_id_int, api_hash)

    # Parse date filters (start_date wins over the deprecated offset_date).
    parsed_start = None
    parsed_end = None
    if start_date:
        parsed_start = datetime.fromisoformat(start_date)
    elif offset_date:  # backward compatibility
        parsed_start = datetime.fromisoformat(offset_date)
    if end_date:
        parsed_end = datetime.fromisoformat(end_date)

    await ensure_login(client, phone=phone, twofa_password=twofa_password)

    # Determine output format based on extension
    ext = os.path.splitext(output)[1].lower()
    is_jsonl = ext in (".jsonl", ".ndjson")
    is_csv = ext == ".csv"

    if not (is_jsonl or is_csv):
        raise ValueError("Output file must end with .jsonl or .csv")

    # Create the output directory before opening the file (bug fix; matches
    # the behavior of fetch_replies and fetch_forwards).
    os.makedirs(os.path.dirname(output) or ".", exist_ok=True)

    # Prepare output writers
    csv_file = None
    csv_writer = None
    jsonl_file = None
    if is_csv:
        import csv
        mode = "a" if append else "w"
        csv_file = open(output, mode, newline="", encoding="utf-8")
        csv_writer = csv.DictWriter(
            csv_file,
            fieldnames=[
                "id",
                "date",
                "message",
                "sender_id",
                "views",
                "forwards",
                "replies",
                "url",
            ],
        )
        # Write header if not appending, or file is empty
        need_header = True
        try:
            if append and os.path.exists(output) and os.path.getsize(output) > 0:
                need_header = False
        except Exception:
            pass
        if need_header:
            csv_writer.writeheader()
    elif is_jsonl:
        # Open once; append or overwrite
        mode = "a" if append else "w"
        jsonl_file = open(output, mode, encoding="utf-8")

    written = 0
    try:
        async for msg in iter_messages(client, channel, limit=None, offset_date=None):
            # Telethon returns tz-aware datetimes; normalize for comparison
            msg_dt = msg.date
            if msg_dt is not None:
                msg_dt = msg_dt.replace(tzinfo=None)

            # Date range filter: include if within [parsed_start, parsed_end] (inclusive)
            if parsed_start and msg_dt and msg_dt < parsed_start:
                # Since we're iterating newest-first, once older than start we can stop
                break
            if parsed_end and msg_dt and msg_dt > parsed_end:
                continue

            rec = message_to_record(msg, channel_username=channel.lstrip("@"))
            if is_jsonl and jsonl_file is not None:
                jsonl_file.write(json.dumps(asdict(rec), ensure_ascii=False) + "\n")
            else:
                csv_writer.writerow(asdict(rec))  # type: ignore
            written += 1
            if limit is not None and written >= limit:
                break
    finally:
        # Always release file handles and the Telegram connection.
        if csv_file:
            csv_file.close()
        if jsonl_file:
            jsonl_file.close()
        await client.disconnect()

    return written
|
||||
|
||||
|
||||
async def fetch_replies(
    channel: str,
    parent_ids: Sequence[int],
    output_csv: str,
    append: bool = False,
    session_name: str = "telegram",
    phone: Optional[str] = None,
    twofa_password: Optional[str] = None,
    concurrency: int = 5,
    existing_pairs: Optional[Set[Tuple[int, int]]] = None,
):
    """Fetch replies for the given parent message IDs and write them to CSV.

    For each parent id, first tries in-channel replies (groups/supergroups);
    on MsgIdInvalidError falls back to the channel's linked discussion group.
    Each reply row carries a VADER compound sentiment score. Parents are
    processed concurrently (bounded by *concurrency*); CSV writes are
    serialized through a lock. When *existing_pairs* is provided (resume
    mode), (parent_id, id) pairs already present are skipped and the set is
    updated in place.

    Bug fix: the FloodWait counters incremented inside handle_parent()
    previously raised UnboundLocalError on the first rate-limit event because
    the closure lacked a ``nonlocal`` declaration; it is declared now.
    """
    load_dotenv()
    api_id = os.getenv("TELEGRAM_API_ID")
    api_hash = os.getenv("TELEGRAM_API_HASH")
    session_name = os.getenv("TELEGRAM_SESSION_NAME", session_name)

    if not api_id or not api_hash:
        raise RuntimeError("Missing TELEGRAM_API_ID/TELEGRAM_API_HASH in environment. See .env.example")
    client = TelegramClient(session_name, int(api_id), api_hash)
    await ensure_login(client, phone=phone, twofa_password=twofa_password)

    import csv

    # Rate limiting counters (shared by all worker closures below)
    flood_hits = 0
    flood_wait_seconds = 0

    analyzer = SentimentIntensityAnalyzer()
    os.makedirs(os.path.dirname(output_csv) or ".", exist_ok=True)
    mode = "a" if append else "w"
    with open(output_csv, mode, newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=["parent_id", "id", "date", "message", "sender_id", "sentiment_compound", "url"],
        )
        # Write header only if not appending or file empty
        need_header = True
        try:
            if append and os.path.exists(output_csv) and os.path.getsize(output_csv) > 0:
                need_header = False
        except Exception:
            pass
        if need_header:
            writer.writeheader()

        write_lock = asyncio.Lock()
        sem = asyncio.Semaphore(max(1, int(concurrency)))

        async def handle_parent(pid: int) -> List[dict]:
            # FIX: these counters are rebound (+=) below; without nonlocal,
            # Python treats them as locals and the first FloodWaitError
            # raised an UnboundLocalError instead of being handled.
            nonlocal flood_hits, flood_wait_seconds
            rows: List[dict] = []
            # First try replies within the same channel (works for groups/supergroups)
            attempts = 0
            while attempts < 3:
                try:
                    async for reply in client.iter_messages(channel, reply_to=pid):
                        dt = reply.date.replace(tzinfo=None) if reply.date else None
                        url = f"https://t.me/{channel.lstrip('@')}/{reply.id}" if reply.id else None
                        text = reply.message or ""
                        sent = analyzer.polarity_scores(text).get("compound")
                        rows.append(
                            {
                                "parent_id": pid,
                                "id": reply.id,
                                "date": to_iso(dt) if dt else None,
                                "message": text,
                                "sender_id": getattr(reply, "sender_id", None),
                                "sentiment_compound": sent,
                                "url": url,
                            }
                        )
                    break
                except FloodWaitError as e:
                    secs = int(getattr(e, 'seconds', 5))
                    flood_hits += 1
                    flood_wait_seconds += secs
                    print(f"[rate-limit] FloodWait while scanning replies in-channel for parent {pid}; waiting {secs}s", flush=True)
                    await asyncio.sleep(secs + 1)
                    attempts += 1
                    continue
                except MsgIdInvalidError:
                    # Likely a channel with a linked discussion group; fall back below
                    rows.clear()
                    break
                except Exception:
                    break

            if rows:
                return rows

            # Fallback: for channels with comments in a linked discussion group
            try:
                res = await client(GetDiscussionMessageRequest(peer=channel, msg_id=pid))
            except Exception:
                # No discussion thread found or not accessible
                return rows

            # Identify the discussion chat and the root message id in that chat
            disc_chat = None
            if getattr(res, "chats", None):
                # Prefer the first chat returned as the discussion chat
                disc_chat = res.chats[0]

            disc_root_id = None
            for m in getattr(res, "messages", []) or []:
                try:
                    peer_id = getattr(m, "peer_id", None)
                    if not peer_id or not disc_chat:
                        continue
                    ch_id = getattr(peer_id, "channel_id", None) or getattr(peer_id, "chat_id", None)
                    if ch_id == getattr(disc_chat, "id", None):
                        disc_root_id = m.id
                        break
                except Exception:
                    continue

            if not disc_chat or not disc_root_id:
                return rows

            group_username = getattr(disc_chat, "username", None)
            attempts = 0
            while attempts < 3:
                try:
                    async for reply in client.iter_messages(disc_chat, reply_to=disc_root_id):
                        dt = reply.date.replace(tzinfo=None) if reply.date else None
                        text = reply.message or ""
                        sent = analyzer.polarity_scores(text).get("compound")
                        # Construct URL only if the discussion group has a public username
                        url = None
                        if group_username and reply.id:
                            url = f"https://t.me/{group_username}/{reply.id}"
                        rows.append(
                            {
                                "parent_id": pid,
                                "id": reply.id,
                                "date": to_iso(dt) if dt else None,
                                "message": text,
                                "sender_id": getattr(reply, "sender_id", None),
                                "sentiment_compound": sent,
                                "url": url,
                            }
                        )
                    break
                except FloodWaitError as e:
                    secs = int(getattr(e, 'seconds', 5))
                    flood_hits += 1
                    flood_wait_seconds += secs
                    print(f"[rate-limit] FloodWait while scanning discussion group for parent {pid}; waiting {secs}s", flush=True)
                    await asyncio.sleep(secs + 1)
                    attempts += 1
                    continue
                except Exception:
                    break
            return rows

        total_written = 0
        processed = 0
        # Sequences know their length directly; no need to materialize a copy.
        total = len(parent_ids) if hasattr(parent_ids, '__len__') else None

        async def worker(pid: int):
            nonlocal total_written, processed
            async with sem:
                rows = await handle_parent(int(pid))
                async with write_lock:
                    if rows:
                        # Dedupe against existing pairs if provided (resume mode)
                        if existing_pairs is not None:
                            filtered: List[dict] = []
                            for r in rows:
                                try:
                                    key = (int(r.get("parent_id")), int(r.get("id")))
                                except Exception:
                                    continue
                                if key in existing_pairs:
                                    continue
                                existing_pairs.add(key)
                                filtered.append(r)
                            rows = filtered
                        if rows:
                            writer.writerows(rows)
                            total_written += len(rows)
                    processed += 1
                    if processed % 10 == 0 or (rows and len(rows) > 0):
                        if total is not None:
                            print(f"[replies] processed {processed}/{total} parents; last parent {pid} wrote {len(rows)} replies; total replies {total_written}", flush=True)
                        else:
                            print(f"[replies] processed {processed} parents; last parent {pid} wrote {len(rows)} replies; total replies {total_written}", flush=True)

        tasks = [asyncio.create_task(worker(pid)) for pid in parent_ids]
        await asyncio.gather(*tasks)

    await client.disconnect()
    if flood_hits:
        print(f"[rate-limit] Summary: {flood_hits} FloodWait events; total waited ~{flood_wait_seconds}s", flush=True)
|
||||
|
||||
|
||||
async def fetch_forwards(
    channel: str,
    parent_ids: Set[int],
    output_csv: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    scan_limit: Optional[int] = None,
    append: bool = False,
    session_name: str = "telegram",
    phone: Optional[str] = None,
    twofa_password: Optional[str] = None,
    concurrency: int = 5,
    chunk_size: int = 1000,
):
    """Best-effort: find forwarded messages within the SAME channel that reference the given parent_ids.

    Telegram API does not provide a global reverse-lookup of forwards across all channels; we therefore scan
    this channel's history and collect messages with fwd_from.channel_post matching a parent id.

    Matching rows (with a VADER compound sentiment score) are appended to
    *output_csv*. Without *scan_limit* the scan is sequential, newest-first,
    stopping early once messages fall before *start_date*; with *scan_limit*
    the history is split into id-range chunks scanned concurrently (bounded
    by *concurrency*), with FloodWait retries and a rate-limit summary.
    """
    load_dotenv()
    api_id = os.getenv("TELEGRAM_API_ID")
    api_hash = os.getenv("TELEGRAM_API_HASH")
    session_name = os.getenv("TELEGRAM_SESSION_NAME", session_name)
    if not api_id or not api_hash:
        raise RuntimeError("Missing TELEGRAM_API_ID/TELEGRAM_API_HASH in environment. See .env.example")
    client = TelegramClient(session_name, int(api_id), api_hash)
    await ensure_login(client, phone=phone, twofa_password=twofa_password)

    import csv

    # Rate limiting counters
    # NOTE(review): only flood_hits is initialized here; flood_wait_seconds is
    # only set in the chunked branch below, which is why the final summary is
    # guarded on scan_limit being provided.
    flood_hits = 0
    # NOTE(review): duplicate of the import a few lines above; harmless.
    import csv

    analyzer = SentimentIntensityAnalyzer()
    os.makedirs(os.path.dirname(output_csv) or ".", exist_ok=True)
    mode = "a" if append else "w"
    write_lock = asyncio.Lock()
    with open(output_csv, mode, newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=["parent_id", "id", "date", "message", "sender_id", "sentiment_compound", "url"],
        )
        # Write the header only for a fresh/empty file.
        need_header = True
        try:
            if append and os.path.exists(output_csv) and os.path.getsize(output_csv) > 0:
                need_header = False
        except Exception:
            pass
        if need_header:
            writer.writeheader()

        parsed_start = datetime.fromisoformat(start_date) if start_date else None
        parsed_end = datetime.fromisoformat(end_date) if end_date else None

        # If no scan_limit provided, fall back to sequential scan to avoid unbounded concurrency
        if scan_limit is None:
            scanned = 0
            matched = 0
            async for msg in client.iter_messages(channel, limit=None):
                # Normalize tz-aware Telethon datetimes for naive comparison.
                dt = msg.date.replace(tzinfo=None) if msg.date else None
                # Newest-first: once older than start_date, nothing further matches.
                if parsed_start and dt and dt < parsed_start:
                    break
                if parsed_end and dt and dt > parsed_end:
                    continue
                fwd = getattr(msg, "fwd_from", None)
                if not fwd:
                    continue
                # channel_post is the original message id in the source channel.
                ch_post = getattr(fwd, "channel_post", None)
                if ch_post and int(ch_post) in parent_ids:
                    text = msg.message or ""
                    sent = analyzer.polarity_scores(text).get("compound")
                    url = f"https://t.me/{channel.lstrip('@')}/{msg.id}" if msg.id else None
                    writer.writerow(
                        {
                            "parent_id": int(ch_post),
                            "id": msg.id,
                            "date": to_iso(dt) if dt else None,
                            "message": text,
                            "sender_id": getattr(msg, "sender_id", None),
                            "sentiment_compound": sent,
                            "url": url,
                        }
                    )
                    matched += 1
                scanned += 1
                if scanned % 1000 == 0:
                    print(f"[forwards] scanned ~{scanned} messages; total forwards {matched}", flush=True)
        else:
            # Concurrent chunked scanning by id ranges
            # Rate limiting counters
            flood_hits = 0
            flood_wait_seconds = 0
            sem = asyncio.Semaphore(max(1, int(concurrency)))
            progress_lock = asyncio.Lock()
            matched_total = 0
            completed_chunks = 0

            # Determine latest message id
            latest_msg = await client.get_messages(channel, limit=1)
            latest_id = None
            try:
                # get_messages may return a list-like; handle both shapes.
                latest_id = getattr(latest_msg, 'id', None) or (latest_msg[0].id if latest_msg else None)
            except Exception:
                latest_id = None
            if not latest_id:
                # Nothing to scan; clean up and bail out.
                await client.disconnect()
                return

            # Ceiling division: number of id-range chunks to cover scan_limit ids.
            total_chunks = max(1, (int(scan_limit) + int(chunk_size) - 1) // int(chunk_size))

            async def process_chunk(idx: int):
                nonlocal flood_hits, flood_wait_seconds
                nonlocal matched_total, completed_chunks
                # Chunk idx covers ids (min_id, max_id], walking down from latest_id.
                max_id = latest_id - idx * int(chunk_size)
                min_id = max(0, max_id - int(chunk_size))
                attempts = 0
                local_matches = 0
                while attempts < 3:
                    try:
                        async with sem:
                            async for msg in client.iter_messages(channel, min_id=min_id, max_id=max_id):
                                dt = msg.date.replace(tzinfo=None) if msg.date else None
                                if parsed_start and dt and dt < parsed_start:
                                    # This range reached before start; skip remaining in this chunk
                                    break
                                if parsed_end and dt and dt > parsed_end:
                                    continue
                                fwd = getattr(msg, "fwd_from", None)
                                if not fwd:
                                    continue
                                ch_post = getattr(fwd, "channel_post", None)
                                if ch_post and int(ch_post) in parent_ids:
                                    text = msg.message or ""
                                    sent = analyzer.polarity_scores(text).get("compound")
                                    url = f"https://t.me/{channel.lstrip('@')}/{msg.id}" if msg.id else None
                                    # Serialize writes: csv writer is shared across chunks.
                                    async with write_lock:
                                        writer.writerow(
                                            {
                                                "parent_id": int(ch_post),
                                                "id": msg.id,
                                                "date": to_iso(dt) if dt else None,
                                                "message": text,
                                                "sender_id": getattr(msg, "sender_id", None),
                                                "sentiment_compound": sent,
                                                "url": url,
                                            }
                                        )
                                    local_matches += 1
                        break
                    except FloodWaitError as e:
                        secs = int(getattr(e, 'seconds', 5))
                        flood_hits += 1
                        flood_wait_seconds += secs
                        print(f"[rate-limit] FloodWait while scanning ids {min_id}-{max_id}; waiting {secs}s", flush=True)
                        await asyncio.sleep(secs + 1)
                        attempts += 1
                        continue
                    except Exception:
                        # best-effort; skip this chunk
                        break
                async with progress_lock:
                    matched_total += local_matches
                    completed_chunks += 1
                    print(
                        f"[forwards] chunks {completed_chunks}/{total_chunks}; last {min_id}-{max_id} wrote {local_matches} forwards; total forwards {matched_total}",
                        flush=True,
                    )

            tasks = [asyncio.create_task(process_chunk(i)) for i in range(total_chunks)]
            await asyncio.gather(*tasks)

    await client.disconnect()
    # Print summary if we used concurrent chunking
    try:
        if scan_limit is not None and 'flood_hits' in locals() and flood_hits:
            print(f"[rate-limit] Summary: {flood_hits} FloodWait events; total waited ~{flood_wait_seconds}s", flush=True)
    except Exception:
        pass
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Telegram scraper utilities")
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
# Subcommand: scrape channel history
|
||||
p_scrape = sub.add_parser("scrape", help="Scrape messages from a channel")
|
||||
p_scrape.add_argument("channel", help="Channel username or t.me link, e.g. @python, https://t.me/python")
|
||||
p_scrape.add_argument("--output", "-o", required=True, help="Output file (.jsonl or .csv)")
|
||||
p_scrape.add_argument("--limit", type=int, default=None, help="Max number of messages to save after filtering")
|
||||
p_scrape.add_argument("--offset-date", dest="offset_date", default=None, help="Deprecated: use --start-date instead. ISO date (inclusive)")
|
||||
p_scrape.add_argument("--start-date", dest="start_date", default=None, help="ISO start date (inclusive)")
|
||||
p_scrape.add_argument("--end-date", dest="end_date", default=None, help="ISO end date (inclusive)")
|
||||
p_scrape.add_argument("--append", action="store_true", help="Append to the output file instead of overwriting")
|
||||
p_scrape.add_argument("--session-name", default=os.getenv("TELEGRAM_SESSION_NAME", "telegram"))
|
||||
p_scrape.add_argument("--phone", default=None)
|
||||
p_scrape.add_argument("--twofa-password", default=os.getenv("TELEGRAM_2FA_PASSWORD"))
|
||||
|
||||
# Subcommand: fetch replies for specific message ids
|
||||
p_rep = sub.add_parser("replies", help="Fetch replies for given message IDs and save to CSV")
|
||||
p_rep.add_argument("channel", help="Channel username or t.me link")
|
||||
src = p_rep.add_mutually_exclusive_group(required=True)
|
||||
src.add_argument("--ids", help="Comma-separated parent message IDs")
|
||||
src.add_argument("--from-csv", dest="from_csv", help="Path to CSV with an 'id' column to use as parent IDs")
|
||||
p_rep.add_argument("--output", "-o", required=True, help="Output CSV path (e.g., data/replies_channel.csv)")
|
||||
p_rep.add_argument("--append", action="store_true", help="Append to the output file instead of overwriting")
|
||||
p_rep.add_argument("--session-name", default=os.getenv("TELEGRAM_SESSION_NAME", "telegram"))
|
||||
p_rep.add_argument("--phone", default=None)
|
||||
p_rep.add_argument("--twofa-password", default=os.getenv("TELEGRAM_2FA_PASSWORD"))
|
||||
p_rep.add_argument("--concurrency", type=int, default=5, help="Number of parent IDs to process in parallel (default 5)")
|
||||
p_rep.add_argument("--min-replies", type=int, default=None, help="When using --from-csv, only process parents with replies >= this value")
|
||||
p_rep.add_argument("--resume", action="store_true", help="Resume mode: skip parent_id,id pairs already present in the output CSV")
|
||||
|
||||
# Subcommand: fetch forwards (same-channel forwards referencing parent ids)
|
||||
p_fwd = sub.add_parser("forwards", help="Best-effort: find forwards within the same channel for given parent IDs")
|
||||
p_fwd.add_argument("channel", help="Channel username or t.me link")
|
||||
src2 = p_fwd.add_mutually_exclusive_group(required=True)
|
||||
src2.add_argument("--ids", help="Comma-separated parent message IDs")
|
||||
src2.add_argument("--from-csv", dest="from_csv", help="Path to CSV with an 'id' column to use as parent IDs")
|
||||
p_fwd.add_argument("--output", "-o", required=True, help="Output CSV path (e.g., data/forwards_channel.csv)")
|
||||
p_fwd.add_argument("--start-date", dest="start_date", default=None)
|
||||
p_fwd.add_argument("--end-date", dest="end_date", default=None)
|
||||
p_fwd.add_argument("--scan-limit", dest="scan_limit", type=int, default=None, help="Max messages to scan in channel history")
|
||||
p_fwd.add_argument("--concurrency", type=int, default=5, help="Number of id-chunks to scan in parallel (requires --scan-limit)")
|
||||
p_fwd.add_argument("--chunk-size", dest="chunk_size", type=int, default=1000, help="Approx. messages per chunk (ids)")
|
||||
p_fwd.add_argument("--append", action="store_true", help="Append to the output file instead of overwriting")
|
||||
p_fwd.add_argument("--session-name", default=os.getenv("TELEGRAM_SESSION_NAME", "telegram"))
|
||||
p_fwd.add_argument("--phone", default=None)
|
||||
p_fwd.add_argument("--twofa-password", default=os.getenv("TELEGRAM_2FA_PASSWORD"))
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Normalize channel
|
||||
channel = getattr(args, "channel", None)
|
||||
if channel and channel.startswith("https://t.me/"):
|
||||
channel = channel.replace("https://t.me/", "@")
|
||||
|
||||
def _normalize_handle(ch: Optional[str]) -> Optional[str]:
|
||||
if not ch:
|
||||
return ch
|
||||
# Expect inputs like '@name' or 'name'; return lowercase without leading '@'
|
||||
return ch.lstrip('@').lower()
|
||||
|
||||
def _extract_handle_from_url(url: str) -> Optional[str]:
|
||||
try:
|
||||
if not url:
|
||||
return None
|
||||
# Accept forms like https://t.me/Name/123 or http(s)://t.me/c/<id>/<msg>
|
||||
# Only public usernames (not /c/ links) can be compared reliably
|
||||
if "/t.me/" in url:
|
||||
# crude parse without urlparse to avoid dependency
|
||||
after = url.split("t.me/")[-1]
|
||||
parts = after.split('/')
|
||||
if parts and parts[0] and parts[0] != 'c':
|
||||
return parts[0]
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
if args.command == "scrape":
|
||||
written = asyncio.run(
|
||||
scrape_channel(
|
||||
channel=channel,
|
||||
output=args.output,
|
||||
limit=args.limit,
|
||||
offset_date=args.offset_date,
|
||||
start_date=args.start_date,
|
||||
end_date=args.end_date,
|
||||
append=getattr(args, "append", False),
|
||||
session_name=args.session_name,
|
||||
phone=args.phone,
|
||||
twofa_password=args.twofa_password,
|
||||
)
|
||||
)
|
||||
print(f"Wrote {written} messages to {args.output}")
|
||||
elif args.command == "replies":
|
||||
# If using --from-csv, try to infer channel from URLs and warn on mismatch
|
||||
try:
|
||||
if getattr(args, 'from_csv', None):
|
||||
import pandas as _pd # local import to keep startup light
|
||||
# Read a small sample of URL column to detect handle
|
||||
sample = _pd.read_csv(args.from_csv, usecols=['url'], nrows=20)
|
||||
url_handles = [
|
||||
_extract_handle_from_url(str(u)) for u in sample['url'].dropna().tolist() if isinstance(u, (str,))
|
||||
]
|
||||
inferred = next((h for h in url_handles if h), None)
|
||||
provided = _normalize_handle(channel)
|
||||
if inferred and provided and _normalize_handle(inferred) != provided:
|
||||
print(
|
||||
f"[warning] CSV appears to be from @{_normalize_handle(inferred)} but you passed -c @{provided}. "
|
||||
f"Replies may be empty. Consider using -c https://t.me/{inferred}",
|
||||
flush=True,
|
||||
)
|
||||
except Exception:
|
||||
# Best-effort only; ignore any issues reading/inspecting CSV
|
||||
pass
|
||||
parent_ids: Set[int]
|
||||
if getattr(args, "ids", None):
|
||||
parent_ids = {int(x.strip()) for x in args.ids.split(",") if x.strip()}
|
||||
else:
|
||||
import pandas as pd # local import
|
||||
usecols = ['id']
|
||||
if args.min_replies is not None:
|
||||
usecols.append('replies')
|
||||
df = pd.read_csv(args.from_csv, usecols=usecols)
|
||||
if args.min_replies is not None and 'replies' in df.columns:
|
||||
df = df[df['replies'].fillna(0).astype(int) >= int(args.min_replies)]
|
||||
parent_ids = set(int(x) for x in df['id'].dropna().astype(int).tolist())
|
||||
existing_pairs = None
|
||||
if args.resume and os.path.exists(args.output):
|
||||
try:
|
||||
import csv as _csv
|
||||
existing_pairs = set()
|
||||
with open(args.output, "r", encoding="utf-8") as _f:
|
||||
reader = _csv.DictReader(_f)
|
||||
for row in reader:
|
||||
try:
|
||||
existing_pairs.add((int(row.get("parent_id")), int(row.get("id"))))
|
||||
except Exception:
|
||||
continue
|
||||
except Exception:
|
||||
existing_pairs = None
|
||||
|
||||
asyncio.run(
|
||||
fetch_replies(
|
||||
channel=channel,
|
||||
parent_ids=sorted(parent_ids),
|
||||
output_csv=args.output,
|
||||
append=getattr(args, "append", False),
|
||||
session_name=args.session_name,
|
||||
phone=args.phone,
|
||||
twofa_password=args.twofa_password,
|
||||
concurrency=max(1, int(getattr(args, 'concurrency', 5))),
|
||||
existing_pairs=existing_pairs,
|
||||
)
|
||||
)
|
||||
print(f"Saved replies to {args.output}")
|
||||
elif args.command == "forwards":
|
||||
parent_ids: Set[int]
|
||||
if getattr(args, "ids", None):
|
||||
parent_ids = {int(x.strip()) for x in args.ids.split(",") if x.strip()}
|
||||
else:
|
||||
import pandas as pd
|
||||
df = pd.read_csv(args.from_csv)
|
||||
parent_ids = set(int(x) for x in df['id'].dropna().astype(int).tolist())
|
||||
asyncio.run(
|
||||
fetch_forwards(
|
||||
channel=channel,
|
||||
parent_ids=parent_ids,
|
||||
output_csv=args.output,
|
||||
start_date=args.start_date,
|
||||
end_date=args.end_date,
|
||||
scan_limit=args.scan_limit,
|
||||
concurrency=max(1, int(getattr(args, 'concurrency', 5))),
|
||||
chunk_size=max(1, int(getattr(args, 'chunk_size', 1000))),
|
||||
append=getattr(args, "append", False),
|
||||
session_name=args.session_name,
|
||||
phone=args.phone,
|
||||
twofa_password=args.twofa_password,
|
||||
)
|
||||
)
|
||||
print(f"Saved forwards to {args.output}")
|
||||
|
||||
|
||||
# Script entry point: parse CLI arguments and dispatch when run directly.
if __name__ == "__main__":
    main()
|
||||
135
src/train_sentiment.py
Normal file
135
src/train_sentiment.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import argparse
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
from datasets import Dataset, ClassLabel
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
|
||||
import inspect
|
||||
import numpy as np
|
||||
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
||||
|
||||
|
||||
def build_dataset(df: pd.DataFrame, text_col: str, label_col: str, label_mapping: Optional[dict] = None) -> "tuple[Dataset, dict, dict]":
    """Build a HF Dataset from labeled rows; return (dataset, label2id, id2label).

    Rows with missing or empty labels are dropped. If *label_mapping* is given,
    raw labels are mapped through it first; rows whose label has no mapping are
    dropped instead of being carried along as NaN (which would crash encoding
    later). String labels are factorized into contiguous integer codes; numeric
    labels are assumed to already be 0..N-1.

    Raises SystemExit when no usable labeled rows remain.
    """
    d = df[[text_col, label_col]].dropna().copy()
    # Normalize and drop empty labels
    d[label_col] = d[label_col].astype(str).str.strip()
    d = d[d[label_col] != '']
    if d.empty:
        raise SystemExit("No labeled rows found. Please fill the 'label' column in your CSV (e.g., neg/neu/pos or 0/1/2).")
    if label_mapping:
        d[label_col] = d[label_col].map(label_mapping)
        # Labels absent from the mapping become NaN; drop them so the
        # factorization below stays clean.
        d = d.dropna(subset=[label_col])
        if d.empty:
            raise SystemExit("No labeled rows left after applying the label mapping; check the mapping keys.")
    # If labels are strings, factorize them
    if d[label_col].dtype == object:
        d[label_col] = d[label_col].astype('category')
        label2id = {k: int(v) for v, k in enumerate(d[label_col].cat.categories)}
        id2label = {v: k for k, v in label2id.items()}
        d[label_col] = d[label_col].cat.codes
    else:
        # Assume numeric 0..N-1
        classes = sorted(d[label_col].unique().tolist())
        label2id = {str(c): int(c) for c in classes}
        id2label = {int(c): str(c) for c in classes}
    hf = Dataset.from_pandas(d.reset_index(drop=True))
    hf = hf.class_encode_column(label_col)
    # NOTE(review): class_encode_column already yields a ClassLabel feature;
    # this override renames the classes to match id2label — confirm the two
    # orderings agree for your data.
    hf.features[label_col] = ClassLabel(num_classes=len(id2label), names=[id2label[i] for i in range(len(id2label))])
    return hf, label2id, id2label
|
||||
|
||||
|
||||
def tokenize_fn(examples, tokenizer, text_col):
    """Tokenize a batch of examples; padding is deferred to the data collator."""
    texts = examples[text_col]
    return tokenizer(texts, truncation=True, padding=False)
|
||||
|
||||
|
||||
def compute_metrics(eval_pred):
    """Compute accuracy plus macro-averaged precision/recall/F1 for a Trainer eval step."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    # zero_division=0 silences undefined-metric warnings for absent classes.
    macro = dict(average='macro', zero_division=0)
    return {
        'accuracy': accuracy_score(labels, predictions),
        'precision_macro': precision_score(labels, predictions, **macro),
        'recall_macro': recall_score(labels, predictions, **macro),
        'f1_macro': f1_score(labels, predictions, **macro),
    }
|
||||
|
||||
|
||||
def main():
    """CLI entry point: fine-tune a sequence-classification model on a labeled CSV.

    Reads labeled rows, builds a HF Dataset (optionally with a stratified
    eval split), fine-tunes the base model, and saves model + tokenizer to
    the output directory.
    """
    parser = argparse.ArgumentParser(description='Fine-tune a transformers model for sentiment.')
    parser.add_argument('--train-csv', required=True, help='Path to labeled CSV')
    parser.add_argument('--text-col', default='message', help='Text column name')
    parser.add_argument('--label-col', default='label', help='Label column name (e.g., pos/neu/neg or 2/1/0)')
    parser.add_argument('--model-name', default='distilbert-base-uncased', help='Base model name or path')
    parser.add_argument('--output-dir', default='models/sentiment-distilbert', help='Where to save the fine-tuned model')
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--eval-split', type=float, default=0.1, help='Fraction of data for eval')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    df = pd.read_csv(args.train_csv)
    ds, label2id, id2label = build_dataset(df, args.text_col, args.label_col)
    if args.eval_split > 0:
        # Stratified split keeps class balance comparable between train and eval.
        ds = ds.train_test_split(test_size=args.eval_split, seed=42, stratify_by_column=args.label_col)
        train_ds, eval_ds = ds['train'], ds['test']
    else:
        train_ds, eval_ds = ds, None

    num_labels = len(id2label)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=num_labels,
        id2label=id2label,
        label2id={k: int(v) for k, v in label2id.items()},
    )

    # padding=False here: dynamic per-batch padding is left to the Trainer's
    # default collator (DataCollatorWithPadding when a tokenizer is supplied).
    tokenized_train = train_ds.map(lambda x: tokenize_fn(x, tokenizer, args.text_col), batched=True)
    tokenized_eval = eval_ds.map(lambda x: tokenize_fn(x, tokenizer, args.text_col), batched=True) if (eval_ds is not None) else None

    # Build TrainingArguments with compatibility across transformers versions
    base_kwargs = {
        'output_dir': args.output_dir,
        'per_device_train_batch_size': args.batch_size,
        'per_device_eval_batch_size': args.batch_size,
        'num_train_epochs': args.epochs,
        'learning_rate': args.lr,
        'fp16': False,
        'logging_steps': 50,
    }
    eval_kwargs = {}
    if tokenized_eval is not None:
        # Set both evaluation_strategy and eval_strategy for compatibility across transformers versions
        eval_kwargs.update({
            'evaluation_strategy': 'epoch',
            'eval_strategy': 'epoch',
            'save_strategy': 'epoch',
            'load_best_model_at_end': True,
            'metric_for_best_model': 'f1_macro',
            'greater_is_better': True,
        })

    # Filter kwargs to only include parameters supported by this transformers version
    sig = inspect.signature(TrainingArguments.__init__)
    allowed = set(sig.parameters.keys())
    def _filter(d: dict) -> dict:
        # Keep only the kwargs this TrainingArguments signature accepts.
        return {k: v for k, v in d.items() if k in allowed}

    training_args = TrainingArguments(**_filter(base_kwargs), **_filter(eval_kwargs))

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_eval,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics if tokenized_eval else None,
    )

    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print(f"Model saved to {args.output_dir}")
|
||||
|
||||
|
||||
# Script entry point: run the fine-tuning CLI when executed directly.
if __name__ == '__main__':
    main()
|
||||
90
src/transformer_sentiment.py
Normal file
90
src/transformer_sentiment.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
import torch
|
||||
|
||||
|
||||
class TransformerSentiment:
    """Batched inference wrapper around a fine-tuned sequence-classification model.

    Loads a tokenizer/model pair from *model_name_or_path*, moves the model to
    the best available device (cuda > mps > cpu) and exposes helpers that turn
    per-class probabilities into a VADER-style compound score in [-1, 1].
    """

    def __init__(self, model_name_or_path: str, device: Optional[str] = None, max_length: int = 256):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)
        self.max_length = max_length
        if device is None:
            # Auto-select: CUDA first, then Apple Silicon MPS, else CPU.
            if torch.cuda.is_available():
                device = 'cuda'
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                device = 'mps'
            else:
                device = 'cpu'
        self.device = device
        self.model.to(self.device)
        self.model.eval()

        # Expect labels roughly like {0:'neg',1:'neu',2:'pos'} or similar
        self.id2label = self.model.config.id2label if hasattr(self.model.config, 'id2label') else {0: '0', 1: '1', 2: '2'}

    def _compound_from_probs(self, probs: np.ndarray) -> float:
        """Map class probabilities to a [-1, 1] compound-like score (p_pos - p_neg).

        Label names are matched case-insensitively against neg/negative and
        pos/positive. If neither matches, fall back to the conventional
        ordering: first class = negative, last class = positive.
        """
        labels = [self.id2label.get(i, str(i)).lower() for i in range(len(probs))]
        try:
            neg_idx = labels.index('neg') if 'neg' in labels else labels.index('negative')
        except ValueError:
            neg_idx = 0  # conventional-ordering fallback
        try:
            pos_idx = labels.index('pos') if 'pos' in labels else labels.index('positive')
        except ValueError:
            pos_idx = len(probs) - 1
        p_neg = float(probs[neg_idx])
        p_pos = float(probs[pos_idx])
        # Simple skew: pos - neg, clamped to [-1, 1].
        return max(-1.0, min(1.0, p_pos - p_neg))

    @torch.no_grad()
    def _batch_probs(self, batch: List[str]) -> np.ndarray:
        """Tokenize one batch and return softmax probabilities, shape (len(batch), num_labels)."""
        enc = self.tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        enc = {k: v.to(self.device) for k, v in enc.items()}
        logits = self.model(**enc).logits
        return torch.softmax(logits, dim=-1).cpu().numpy()

    @torch.no_grad()
    def predict_compound_batch(self, texts: List[str], batch_size: int = 32) -> List[float]:
        """Return one compound score per input text, processed in mini-batches."""
        out: List[float] = []
        for i in range(0, len(texts), batch_size):
            probs = self._batch_probs(texts[i:i + batch_size])
            out.extend(self._compound_from_probs(row) for row in probs)
        return out

    @torch.no_grad()
    def predict_probs_and_labels(self, texts: List[str], batch_size: int = 32):
        """Return (list of per-class probability rows, list of argmax label names)."""
        probs_all = []
        labels_all: List[str] = []
        for i in range(0, len(texts), batch_size):
            probs = self._batch_probs(texts[i:i + batch_size])
            preds = probs.argmax(axis=-1)
            for j, row in enumerate(probs):
                probs_all.append(row)
                labels_all.append(self.id2label.get(int(preds[j]), str(int(preds[j]))))
        return probs_all, labels_all
|
||||
Reference in New Issue
Block a user