Spaces:
Sleeping
Sleeping
import matplotlib.pyplot as plt | |
import pandas as pd | |
import os | |
import logging | |
import time | |
class ChartGenerator:
    """Render line or bar charts from a pandas DataFrame into a PNG file.

    Every chart is written to ``static/images/chart.png`` under the parent of
    the directory that contains this module, overwriting any previous chart.
    """

    # Output image name; each call overwrites the same file.
    # NOTE(review): concurrent calls will clobber each other's output.
    CHART_FILENAME = 'chart.png'

    def __init__(self, data=None):
        """Create a generator for ``data``.

        Args:
            data: DataFrame to plot. When ``None``, ``data/sample_data.xlsx``
                under the parent of this module's directory is loaded with
                ``pd.read_excel``.
        """
        logging.info("Initializing ChartGenerator")
        if data is not None:
            self.data = data
        else:
            self.data = pd.read_excel(
                os.path.join(self._project_root(), 'data', 'sample_data.xlsx')
            )

    @staticmethod
    def _project_root():
        """Return the parent of the directory containing this module."""
        # abspath guards against a relative __file__, for which two dirname()
        # calls can collapse to '' and silently resolve against the CWD.
        return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    def _validate_columns(self, x_col, y_cols):
        """Raise ValueError listing every requested column absent from the data."""
        missing_cols = [col for col in [x_col, *y_cols] if col not in self.data.columns]
        if missing_cols:
            logging.error(f"Missing columns in data: {missing_cols}")
            logging.info(f"Available columns: {list(self.data.columns)}")
            raise ValueError(f"Missing columns in data: {missing_cols}")

    def _draw_series(self, ax, x_col, y_cols, chart_type, color):
        """Plot each y column against ``x_col`` on ``ax`` as bars or marker lines."""
        if chart_type == 'bar' and len(y_cols) > 1:
            # Bar series drawn at identical x positions would hide one another,
            # so lay them out side by side around integer slots and label the
            # slots with the x values.
            slots = range(len(self.data))
            width = 0.8 / len(y_cols)  # 0.8 = matplotlib's default bar width
            for i, y in enumerate(y_cols):
                offset = (i - (len(y_cols) - 1) / 2) * width
                ax.bar([s + offset for s in slots], self.data[y],
                       width=width, label=y, color=color)
            ax.set_xticks(list(slots))
            ax.set_xticklabels([str(v) for v in self.data[x_col]])
            return
        for y in y_cols:
            if chart_type == 'bar':
                ax.bar(self.data[x_col], self.data[y], label=y, color=color)
            else:
                ax.plot(self.data[x_col], self.data[y], label=y, color=color, marker='o')

    def generate_chart(self, plot_args):
        """Render the chart described by ``plot_args`` and save it as a PNG.

        Args:
            plot_args: Mapping with keys
                ``'x'``: column for the x axis (required);
                ``'y'``: a column name or an iterable of column names (required);
                ``'chart_type'``: ``'bar'`` draws bars, any other value draws
                lines with markers (default ``'line'``);
                ``'color'``: color passed to every series (default ``None``).

        Returns:
            ``os.path.join('static', 'images', 'chart.png')``, i.e. the saved
            image's path relative to the parent of this module's directory.

        Raises:
            KeyError: if ``'x'`` or ``'y'`` is missing from ``plot_args``.
            ValueError: if any requested column is not in the data.
            FileNotFoundError: if the image does not exist after saving.
        """
        start_time = time.time()
        logging.info(f"Generating chart with arguments: {plot_args}")
        x_col = plot_args['x']
        y_cols = plot_args['y']
        # A bare column name would otherwise be iterated character by character.
        y_cols = [y_cols] if isinstance(y_cols, str) else list(y_cols)
        # Validate before touching pyplot so bad input fails fast.
        self._validate_columns(x_col, y_cols)

        chart_type = plot_args.get('chart_type', 'line')
        color = plot_args.get('color', None)

        # Discard figures left in pyplot's global registry by earlier calls.
        plt.close('all')
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            self._draw_series(ax, x_col, y_cols, chart_type, color)
            ax.set_xlabel(x_col)
            ax.set_ylabel('Value')
            ax.set_title(f'{chart_type.title()} Chart')
            if y_cols:  # legend() warns when there are no labelled artists
                ax.legend()
            ax.grid(True, alpha=0.3)
            # Rotate x-axis labels when there are enough points to crowd them.
            if len(self.data) > 5:
                ax.tick_params(axis='x', labelrotation=45)

            output_dir = os.path.join(self._project_root(), 'static', 'images')
            if not os.path.isdir(output_dir):
                os.makedirs(output_dir, exist_ok=True)
                logging.info(f"Created output directory: {output_dir}")
            full_path = os.path.join(output_dir, self.CHART_FILENAME)
            if os.path.exists(full_path):
                os.remove(full_path)
                logging.info(f"Removed existing chart file: {full_path}")
            # Save with high DPI for better quality.
            fig.savefig(full_path, dpi=300, bbox_inches='tight', facecolor='white')
        finally:
            # Release the figure even if drawing or saving raised.
            plt.close(fig)

        # Verify the file was actually written.
        if not os.path.exists(full_path):
            logging.error(f"Failed to create chart file at {full_path}")
            raise FileNotFoundError(f"Chart file was not created at {full_path}")
        file_size = os.path.getsize(full_path)
        logging.info(f"Chart generated and saved to {full_path} (size: {file_size} bytes)")
        logging.info(f"Chart generation took {time.time() - start_time:.2f}s")
        return os.path.join('static', 'images', self.CHART_FILENAME)