kjn1009 commited on
Commit
f94173d
ยท
verified ยท
1 Parent(s): 0aea0cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -8,7 +8,7 @@ import gradio as gr
8
  class MyModel(torch.nn.Module):
9
  def __init__(self, num_classes):
10
  super(MyModel, self).__init__()
11
- self.model = models.resnet18(pretrained=False) # ResNet18 ๋ชจ๋ธ ์‚ฌ์šฉ
12
  self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes) # ๋งˆ์ง€๋ง‰ ๋ ˆ์ด์–ด ์ˆ˜์ •
13
 
14
  def forward(self, x):
@@ -17,8 +17,9 @@ class MyModel(torch.nn.Module):
17
  # ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
18
  num_classes = 6 # ํด๋ž˜์Šค ์ˆ˜์— ๋งž๊ฒŒ ์„ค์ •
19
  model = MyModel(num_classes)
20
- # ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ๋กœ๋“œ
21
- model.load_state_dict(torch.load('resnet18_finetuned_2.pth', weights_only=True), strict=False)
 
22
  model.eval()
23
 
24
  # ๋ฐ์ดํ„ฐ ๋ณ€ํ™˜ ์ •์˜
 
8
  class MyModel(torch.nn.Module):
9
  def __init__(self, num_classes):
10
  super(MyModel, self).__init__()
11
+ self.model = models.resnet18(weights=None) # ResNet18 ๋ชจ๋ธ ์‚ฌ์šฉ
12
  self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes) # ๋งˆ์ง€๋ง‰ ๋ ˆ์ด์–ด ์ˆ˜์ •
13
 
14
  def forward(self, x):
 
17
  # ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
18
  num_classes = 6 # ํด๋ž˜์Šค ์ˆ˜์— ๋งž๊ฒŒ ์„ค์ •
19
  model = MyModel(num_classes)
20
+
21
+ # ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ๋กœ๋“œ (CPU๋กœ ๋งคํ•‘)
22
+ model.load_state_dict(torch.load('resnet18_finetuned_2.pth', map_location=torch.device('cpu')), strict=False)
23
  model.eval()
24
 
25
  # ๋ฐ์ดํ„ฐ ๋ณ€ํ™˜ ์ •์˜