import os import subprocess import traceback from datetime import datetime, timedelta import gradio as gr import numpy as np import pandas as pd import plotly.graph_objects as go import pytz from config import STATION_NAMES from supabase_utils import ( get_harmonic_predictions, save_predictions_to_supabase, get_supabase_client ) from preprocessing import preprocess_uploaded_file def get_common_args(station_id): return [ "--model", "TimeXer", "--features", "MS", "--seq_len", "144", "--pred_len", "72", "--label_len", "96", "--enc_in", "5", "--dec_in", "5", "--c_out", "1", "--d_model", "256", "--d_ff", "512", "--n_heads", "8", "--e_layers", "1", "--d_layers", "1", "--factor", "3", "--patch_len", "16", "--expand", "2", "--d_conv", "4" ] def validate_csv_file(file_path, required_rows=144): """CSV 파일 유효성 검사 - tide_level 또는 residual 지원""" try: df = pd.read_csv(file_path) # 기본 필수 컬럼 (tide_level 또는 residual 중 하나는 있어야 함) base_columns = ['date', 'air_pres', 'wind_dir', 'wind_speed', 'air_temp'] missing_base = [col for col in base_columns if col not in df.columns] if missing_base: return False, f"필수 컬럼이 누락되었습니다: {missing_base}" # tide_level 또는 residual 중 하나는 있어야 함 has_tide_level = 'tide_level' in df.columns has_residual = 'residual' in df.columns if not has_tide_level and not has_residual: return False, "tide_level 또는 residual 컬럼이 필요합니다." if len(df) < required_rows: return False, f"데이터가 부족합니다. 최소 {required_rows}행 필요, 현재 {len(df)}행" data_type = "tide_level" if has_tide_level else "residual" return True, f"파일이 유효합니다. (데이터 형태: {data_type})" except Exception as e: return False, f"파일 읽기 오류: {str(e)}" def execute_inference_and_get_results(command): """inference 실행하고 결과 파일을 읽어서 반환""" try: print(f"실행 명령어: {' '.join(command)}") result = subprocess.run(command, capture_output=True, text=True, timeout=300) if result.returncode != 0: error_message = ( f"실행 실패 (Exit Code: {result.returncode}):\n\n" f"--- 에러 로그 ---\n{result.stderr}\n\n" f"--- 일반 출력 ---\n{result.stdout}" ) raise gr.Error(error_message) return True, result.stdout except subprocess.TimeoutExpired: raise gr.Error("실행 시간이 초과되었습니다. (5분 제한)") except Exception as e: raise gr.Error(f"내부 오류: {str(e)}") def calculate_final_tide(residual_predictions, station_id, last_time): """잔차 예측 + 조화 예측 = 최종 조위 계산""" if isinstance(last_time, pd.Timestamp): last_time = last_time.to_pydatetime() kst = pytz.timezone('Asia/Seoul') if last_time.tzinfo is None: last_time = kst.localize(last_time) start_time = last_time + timedelta(minutes=5) end_time = last_time + timedelta(minutes=72*5) harmonic_data = get_harmonic_predictions(station_id, start_time, end_time) residual_flat = residual_predictions.flatten() num_points = len(residual_flat) if not harmonic_data: print("조화 예측 데이터를 찾을 수 없습니다. 잔차 예측만 반환합니다.") return { 'times': [last_time + timedelta(minutes=(i+1)*5) for i in range(num_points)], 'residual': residual_flat.tolist(), 'harmonic': [0.0] * num_points, 'final_tide': residual_flat.tolist() } final_results = { 'times': [], 'residual': [], 'harmonic': [], 'final_tide': [] } harmonic_dict = {} for h_data in harmonic_data: h_time_str = h_data['predicted_at'] try: if 'T' in h_time_str: if h_time_str.endswith('Z'): h_time = datetime.fromisoformat(h_time_str[:-1] + '+00:00') elif '+' in h_time_str or '-' in h_time_str[-6:]: h_time = datetime.fromisoformat(h_time_str) else: h_time = datetime.fromisoformat(h_time_str + '+00:00') else: from dateutil import parser h_time = parser.parse(h_time_str) if h_time.tzinfo is None: h_time = pytz.UTC.localize(h_time) h_time = h_time.astimezone(kst) except Exception as e: print(f"시간 파싱 오류: {h_time_str}, {e}") continue minutes = (h_time.minute // 5) * 5 h_time = h_time.replace(minute=minutes, second=0, microsecond=0) harmonic_value = float(h_data['harmonic_level']) harmonic_dict[h_time] = harmonic_value for i, residual in enumerate(residual_flat): pred_time = last_time + timedelta(minutes=(i+1)*5) pred_time = pred_time.replace(second=0, microsecond=0) harmonic_value = harmonic_dict.get(pred_time, 0.0) if harmonic_value == 0.0 and harmonic_dict: min_diff = float('inf') for h_time, h_val in harmonic_dict.items(): diff = abs((h_time - pred_time).total_seconds()) if diff < min_diff and diff < 300: min_diff = diff harmonic_value = h_val final_tide = float(residual) + harmonic_value final_results['times'].append(pred_time) final_results['residual'].append(float(residual)) final_results['harmonic'].append(harmonic_value) final_results['final_tide'].append(final_tide) return final_results def create_enhanced_prediction_plot(prediction_results, input_data, station_name): """잔차 + 조화 + 최종 조위를 모두 표시하는 향상된 플롯""" try: input_df = pd.read_csv(input_data.name) input_df['date'] = pd.to_datetime(input_df['date']) recent_data = input_df.tail(24) future_times = pd.to_datetime(prediction_results['times']) fig = go.Figure() fig.add_trace(go.Scatter( x=recent_data['date'], y=recent_data['residual'], mode='lines+markers', name='실제 잔차조위', line=dict(color='blue', width=2), marker=dict(size=4) )) fig.add_trace(go.Scatter( x=future_times, y=prediction_results['residual'], mode='lines+markers', name='잔차 예측', line=dict(color='red', width=2, dash='dash'), marker=dict(size=3) )) if any(h != 0 for h in prediction_results['harmonic']): fig.add_trace(go.Scatter( x=future_times, y=prediction_results['harmonic'], mode='lines', name='조화 예측', line=dict(color='orange', width=2) )) fig.add_trace(go.Scatter( x=future_times, y=prediction_results['final_tide'], mode='lines+markers', name='최종 조위', line=dict(color='green', width=3), marker=dict(size=4) )) last_time = recent_data['date'].iloc[-1] fig.add_annotation( x=last_time, y=0, text="← 과거 | 미래 →", showarrow=False, yref="paper", yshift=10, font=dict(size=12, color="gray") ) fig.update_layout( title=f'{station_name} 통합 조위 예측 결과', xaxis_title='시간', yaxis_title='수위 (cm)', hovermode='x unified', height=600, showlegend=True, xaxis=dict(tickformat='%H:%M
%m/%d', gridcolor='lightgray', showgrid=True), yaxis=dict(gridcolor='lightgray', showgrid=True), plot_bgcolor='white' ) return fig except Exception as e: print(f"Enhanced plot creation error: {e}") traceback.print_exc() fig = go.Figure() fig.add_annotation( text=f"시각화 생성 중 오류: {str(e)}", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False ) return fig def single_prediction(station_id, input_csv_file): if input_csv_file is None: raise gr.Error("예측을 위한 입력 파일을 업로드해주세요.") # 1. 초기 파일 검증 is_valid, message = validate_csv_file(input_csv_file.name) if not is_valid: raise gr.Error(f"파일 오류: {message}") station_name = STATION_NAMES.get(station_id, station_id) # 2. 전처리 수행 (tide_level → residual 변환 포함) gr.Info(f"📊 {station_name}({station_id}) 데이터 전처리 중...") processed_data, preprocess_result = preprocess_uploaded_file(input_csv_file.name, station_id) if processed_data is None: raise gr.Error(f"전처리 실패: {preprocess_result}") # 전처리 결과가 문자열(에러)인지 딕셔너리(성공)인지 확인 if isinstance(preprocess_result, str): raise gr.Error(f"전처리 오류: {preprocess_result}") # 전처리된 파일 경로 사용 processed_file_path = preprocess_result['output_file'] common_args = get_common_args(station_id) setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0" checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth" scaler_path = f"./checkpoints/{setting_name}/scaler.gz" if not os.path.exists(checkpoint_path): raise gr.Error(f"모델 파일을 찾을 수 없습니다: {checkpoint_path}") if not os.path.exists(scaler_path): raise gr.Error(f"스케일러 파일을 찾을 수 없습니다: {scaler_path}") # 전처리된 파일을 inference에 전달 command = ["python", "inference.py", "--checkpoint_path", checkpoint_path, "--scaler_path", scaler_path, "--predict_input_file", processed_file_path] + common_args gr.Info(f"{station_name}({station_id}) 통합 조위 예측을 실행중입니다...") success, output = execute_inference_and_get_results(command) try: prediction_file = "pred_results/prediction_future.npy" if os.path.exists(prediction_file): residual_predictions = np.load(prediction_file) # 전처리된 데이터 사용 input_df = processed_data last_time = input_df['date'].iloc[-1] prediction_results = calculate_final_tide(residual_predictions, station_id, last_time) # 플롯은 전처리된 데이터 파일을 사용 plot = create_enhanced_prediction_plot(prediction_results, type('obj', (object,), {'name': processed_file_path}), station_name) has_harmonic = any(h != 0 for h in prediction_results['harmonic']) if has_harmonic: result_df = pd.DataFrame({ '예측 시간': [t.strftime('%Y-%m-%d %H:%M') for t in prediction_results['times']], '잔차 예측 (cm)': [f"{val:.2f}" for val in prediction_results['residual']], '조화 예측 (cm)': [f"{val:.2f}" for val in prediction_results['harmonic']], '최종 조위 (cm)': [f"{val:.2f}" for val in prediction_results['final_tide']] }) else: result_df = pd.DataFrame({ '예측 시간': [t.strftime('%Y-%m-%d %H:%M') for t in prediction_results['times']], '잔차 예측 (cm)': [f"{val:.2f}" for val in prediction_results['residual']] }) saved_count = save_predictions_to_supabase(station_id, prediction_results) if saved_count > 0: save_message = f"\n💾 Supabase에 {saved_count}개 예측 결과 저장 완료!" elif get_supabase_client() is None: save_message = "\n⚠️ Supabase 연결 실패 (환경변수 확인 필요)" else: save_message = "\n⚠️ Supabase 저장 실패" # 전처리 정보 추가 preprocess_info = f"""📊 전처리 결과: - 원본 데이터: {preprocess_result['original_rows']}행 - 처리 데이터: {preprocess_result['processed_rows']}행 - Residual 평균: {preprocess_result['residual_mean']:.2f}cm - Residual 표준편차: {preprocess_result['residual_std']:.2f}cm""" return plot, result_df, f"✅ 예측 완료!{save_message}\n\n{preprocess_info}\n\n{output}" else: return None, None, f"❌ 결과 파일을 찾을 수 없습니다.\n\n{output}" except Exception as e: print(f"Result processing error: {e}") traceback.print_exc() return None, None, f"❌ 결과 처리 중 오류: {str(e)}\n\n{output}"